DaskStreaming.py
    # Copyright 2022-2024 ETSI OSG/SDG TeraFlowSDN (TFS) (https://tfs.etsi.org/)
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #      http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
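
    # DaskStreaming: consumes KPI samples from a Kafka topic, buckets them into
    # fixed time windows, aggregates each window on a local Dask cluster, applies
    # fall/raise thresholds, and publishes the resulting alarm records back to Kafka.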
    
    import logging
    import time
    import json
    from confluent_kafka import Consumer, Producer, KafkaException, KafkaError
    import pandas as pd
    from dask.distributed import Client, LocalCluster
    from common.tools.kafka.Variables import KafkaConfig, KafkaTopic
    
    logging.basicConfig(level=logging.INFO)
    LOGGER = logging.getLogger(__name__)
    
    def SettingKafkaConsumerParams():
        return {'bootstrap.servers'  : KafkaConfig.get_kafka_address(),
                'group.id'           : 'analytics-backend',
                'auto.offset.reset'  : 'latest'}
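
    # Note: with 'auto.offset.reset' = 'latest', the 'analytics-backend' group starts
    # at the end of the topic when it has no committed offset, so previously retained
    # samples are not replayed.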
    
    def GetAggregationMappings(thresholds):
        agg_dict = {}
        for threshold_key in thresholds.keys():
            parts = threshold_key.split('_', 1)
            if len(parts) != 2:
                LOGGER.warning(f"Threshold key '{threshold_key}' does not follow the '<aggregation>_<metricName>' format. Skipping.")
                continue
            aggregation, metric_name = parts
            # Ensure that the aggregation function is valid in pandas
            if aggregation not in ['mean', 'min', 'max', 'first', 'last', 'std']:
                LOGGER.warning(f"Unsupported aggregation '{aggregation}' in threshold key '{threshold_key}'. Skipping.")
                continue
            agg_dict[threshold_key] = ('kpi_value', aggregation)
        return agg_dict
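
    # Illustrative example (hypothetical threshold name): a thresholds dict such as
    #   {'mean_kpi_value': [10.0, 100.0]}
    # yields agg_dict == {'mean_kpi_value': ('kpi_value', 'mean')}, which can be passed
    # to pandas named aggregation as df.groupby(...).agg(**agg_dict).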
    
    def ApplyThresholds(aggregated_df, thresholds):
        """
        Apply thresholds (TH-Fall and TH-Raise) based on the thresholds dictionary
        on the aggregated DataFrame.
        Args:       aggregated_df (pd.DataFrame): DataFrame with aggregated metrics.
                    thresholds (dict): Thresholds dictionary with keys in the format '<aggregation>_<metricName>'.
        Returns:    pd.DataFrame: DataFrame with additional threshold columns.
        """
        for threshold_key, threshold_values in thresholds.items():
            if threshold_key not in aggregated_df.columns:
                LOGGER.warning(f"Threshold key '{threshold_key}' does not correspond to any aggregation result. Skipping threshold application.")
                continue
            if isinstance(threshold_values, (list, tuple)) and len(threshold_values) == 2:
                fall_th, raise_th = threshold_values
                aggregated_df[f"{threshold_key}_THRESHOLD_FALL"] = aggregated_df[threshold_key] < fall_th
                aggregated_df[f"{threshold_key}_THRESHOLD_RAISE"] = aggregated_df[threshold_key] > raise_th
            else:
                LOGGER.warning(f"Threshold values for '{threshold_key}' are not a list or tuple of length 2. Skipping threshold application.")
        return aggregated_df
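
    # Illustrative example (hypothetical key): for threshold_key 'mean_kpi_value' with
    # values [10.0, 100.0], two boolean columns are added:
    #   'mean_kpi_value_THRESHOLD_FALL'  -> True where the aggregate is below 10.0
    #   'mean_kpi_value_THRESHOLD_RAISE' -> True where the aggregate is above 100.0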
    
    def initialize_dask_client():
        """
        Initialize a local Dask cluster and client.
        """
        cluster = LocalCluster(n_workers=2, threads_per_worker=2)
        client = Client(cluster)
        LOGGER.info(f"Dask Client Initialized: {client}")
        return client, cluster
    
    def initialize_kafka_producer():
        return Producer({'bootstrap.servers': KafkaConfig.get_kafka_address()})
    
    def delivery_report(err, msg):
        if err is not None:
            LOGGER.error(f"Message delivery failed: {err}")
        else:
            LOGGER.info(f"Message delivered to {msg.topic()} [{msg.partition()}] at offset {msg.offset()}")
    
    
    def process_batch(batch, agg_mappings, thresholds, key):
        """
        Process a batch of data and apply thresholds.
        Args:       batch (list of dict): List of messages from Kafka.
                    agg_mappings (dict): Mapping from threshold key to aggregation function.
                    thresholds (dict): Thresholds dictionary.
                    key: Streamer identifier passed through from DaskStreamer (not used in the computation).
        Returns:    list of dict: Processed records ready to be sent to Kafka.
        """
        if not batch:
            LOGGER.info("Empty batch received. Skipping processing.")
            return []

        df = pd.DataFrame(batch)
        LOGGER.info(f"Processing batch of {len(df)} messages.")

        required_columns = {'time_stamp', 'kpi_id', 'kpi_value'}
        if not required_columns.issubset(df.columns):
            LOGGER.warning(f"Batch is missing required columns. Required columns: {required_columns}. Skipping batch.")
            return []

        # Parse timestamps and drop rows whose timestamp could not be parsed
        df['time_stamp'] = pd.to_datetime(df['time_stamp'], errors='coerce', unit='s')
        df.dropna(subset=['time_stamp'], inplace=True)
        if df.empty:
            LOGGER.info("No data left after dropping rows with invalid timestamps. Skipping processing.")
            return []
    
        # Perform aggregations using pandas named aggregation
        try:
            agg_dict = dict(agg_mappings)   # copy; avoids shadowing the 'key' parameter
            df_agg = df.groupby(['window_start']).agg(**agg_dict).reset_index()
    
        except Exception as e:
            LOGGER.error(f"Aggregation error: {e}")
            return []
    
        # Apply thresholds
        df_thresholded = ApplyThresholds(df_agg, thresholds)
    
        df_thresholded['window_start'] = df_thresholded['window_start'].dt.strftime('%Y-%m-%dT%H:%M:%SZ')
        # Convert aggregated DataFrame to list of dicts
        result = df_thresholded.to_dict(orient='records')
        LOGGER.info(f"Processed batch with {len(result)} records after aggregation and thresholding.")
    
        return result
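
    # Illustrative shape of one output record (hypothetical values):
    #   {'window_start': '2024-01-01T00:00:00Z', 'mean_kpi_value': 12.3,
    #    'mean_kpi_value_THRESHOLD_FALL': False, 'mean_kpi_value_THRESHOLD_RAISE': False}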
    
    def produce_result(result, producer, destination_topic):
        for record in result:
            try:
                producer.produce(
                    destination_topic,
                    key=str(record.get('kpi_id', '')),
                    value=json.dumps(record),
                    callback=delivery_report
                )
            except KafkaException as e:
                LOGGER.error(f"Failed to produce message: {e}")
        producer.flush()
        LOGGER.info(f"Produced {len(result)} aggregated records to '{destination_topic}'.")
    
    def DaskStreamer(key, kpi_list, thresholds, stop_event,
                    window_size="30s", time_stamp_col="time_stamp"):
        client, cluster = initialize_dask_client()
        consumer_conf   = SettingKafkaConsumerParams()
        consumer        = Consumer(consumer_conf)
        consumer.subscribe([KafkaTopic.VALUE.value])
        producer        = initialize_kafka_producer()
    
        # Parse window_size to seconds
        try:
            window_size_td = pd.to_timedelta(window_size)
            window_size_seconds = window_size_td.total_seconds()
        except Exception as e:
            LOGGER.error(f"Invalid window_size format: {window_size}. Error: {e}")
            window_size_seconds = 30 
        LOGGER.info(f"Batch processing interval set to {window_size_seconds} seconds.")
    
        # Extract aggregation mappings from thresholds
        agg_mappings = GetAggregationMappings(thresholds)
        if not agg_mappings:
            LOGGER.error("No valid aggregation mappings extracted from thresholds. Exiting streamer.")
            consumer.close()
            producer.flush()
            client.close()
            cluster.close()
            return
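
        # Main consume loop: messages are buffered locally per time window; once
        # 'window_size' has elapsed, the buffered batch is submitted to Dask for
        # aggregation and thresholding, and the results are produced to the ALARMS topic.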
        try:
            batch = []
            last_batch_time = time.time()
            LOGGER.info("Starting to consume messages...")
    
            while not stop_event.is_set():
                msg = consumer.poll(1.0)
    
                if msg is None:
                    current_time = time.time()
                    if (current_time - last_batch_time) >= window_size_seconds and batch:
                        LOGGER.info("Time-based batch threshold reached. Processing batch.")
                        future = client.submit(process_batch, batch, agg_mappings, thresholds, key)
                        future.add_done_callback(lambda fut: produce_result(fut.result(), producer, KafkaTopic.ALARMS.value))
                        batch = []
                        last_batch_time = current_time
                    continue
    
                if msg.error():
                    if msg.error().code() == KafkaError._PARTITION_EOF:
                        LOGGER.warning(f"End of partition reached {msg.topic()} [{msg.partition()}] at offset {msg.offset()}")
                    else:
                        LOGGER.error(f"Kafka error: {msg.error()}")
                    continue
    
                try:
                    message_value = json.loads(msg.value().decode('utf-8'))
                except json.JSONDecodeError as e:
                    LOGGER.error(f"JSON decode error: {e}")
                    continue
    
                try:
                    message_timestamp = pd.to_datetime(message_value[time_stamp_col], errors='coerce', unit='s')
                    if pd.isna(message_timestamp):
                        LOGGER.warning(f"Invalid timestamp in message: {message_value}. Skipping message.")
                        continue
                    # Assign the message to its time window
                    window_start = message_timestamp.floor(window_size)
                    LOGGER.debug(f"message_timestamp: {message_timestamp}, window_start: {window_start}")
                    message_value['window_start'] = window_start
                except Exception as e:
                    LOGGER.error(f"Error processing timestamp: {e}. Skipping message.")
                    continue
    
                if message_value['kpi_id'] not in kpi_list:
                    LOGGER.debug(f"KPI ID '{message_value['kpi_id']}' not in kpi_list. Skipping message.")
                    continue
    
                batch.append(message_value)
    
                current_time = time.time()
                if (current_time - last_batch_time) >= window_size_seconds and batch:
                    LOGGER.info("Time-based batch threshold reached. Processing batch.")
    
                    future = client.submit(process_batch, batch, agg_mappings, thresholds, key)
    
                    future.add_done_callback(lambda fut: produce_result(fut.result(), producer, KafkaTopic.ALARMS.value))
                    batch = []
                    last_batch_time = current_time
    
        except Exception as e:
            LOGGER.exception(f"Error in Dask streaming process: {e}")
        finally:
            # Process any remaining messages in the batch
            if batch:
                LOGGER.info("Processing remaining messages in the batch.")
                future = client.submit(process_batch, batch, agg_mappings, thresholds, key)
                future.add_done_callback(lambda fut: produce_result(fut.result(), producer, KafkaTopic.ALARMS.value))
            consumer.close()
            producer.flush()
            LOGGER.info("Kafka consumer and producer closed.")
            client.close()
            cluster.close()
            LOGGER.info("Dask client and cluster closed.")