# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from datetime import datetime
from datetime import timedelta

import os
import numpy as np
import onnxruntime as rt
import logging
import time

from common.proto.l3_centralizedattackdetector_pb2 import Empty
from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorServicer

from common.proto.l3_attackmitigator_pb2 import L3AttackmitigatorOutput

from common.proto.monitoring_pb2 import KpiDescriptor
from common.proto.kpi_sample_types_pb2 import KpiSampleType

from monitoring.client.MonitoringClient import MonitoringClient
from common.proto.monitoring_pb2 import Kpi

from common.tools.timestamp.Converters import timestamp_utcnow_to_float
from common.proto.context_pb2 import Timestamp, SliceId, ConnectionId

from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigatorClient

import uuid


LOGGER = logging.getLogger(__name__)
current_dir = os.path.dirname(os.path.abspath(__file__))
MODEL_FILE = os.path.join(current_dir, "ml_model/crypto_auto_features.onnx")


class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer):
    """Servicer of the Centralized Attack Detector (CAD).

    Classifies every received connection with an ONNX model, reports
    aggregated KPIs through the Monitoring client and notifies the Attack
    Mitigator when a connection is classified as cryptomining.
    """

    def __init__(self):
        """Initialize variables, prediction model and clients of components used by CAD.

        Raises:
            ValueError: if CAD_CLASSIFICATION_THRESHOLD or
                MONITORED_KPIS_TIME_INTERVAL_AGG is set to a non-numeric value,
                or if the aggregation interval is not strictly positive.
        """
        LOGGER.info("Creating Centralized Attack Detector Service")

        self.inference_values = []
        # Each entry is {"output": <dict built by make_inference>, "timestamp": <datetime>}
        self.inference_results = []
        self.model = rt.InferenceSession(MODEL_FILE)

        # Custom metadata values of the model, converted to int and sorted
        self.meta = sorted(int(x) for x in self.model.get_modelmeta().custom_metadata_map.values())

        LOGGER.debug(self.meta)
        LOGGER.debug("Prueba onnx")

        self.input_name = self.model.get_inputs()[0].name
        self.label_name = self.model.get_outputs()[0].name
        self.prob_name = self.model.get_outputs()[1].name

        self.monitoring_client = MonitoringClient()
        self.service_ids = []
        # NOTE(review): "kpi_id" holds a single identifier per KPI name; create_kpis overwrites
        # it for every new service and monitor_kpis uses it for all services - confirm whether
        # per-service identifiers are required.
        self.monitored_kpis = {
            "l3_security_status": {
                "kpi_id": None,
                "description": "L3 - Security status of the service in a time interval of the service {service_id} (“0” if no attack has been detected on the service and “1” if a cryptomining attack has been detected)",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_SECURITY_STATUS_CRYPTO,
                "service_ids": [],
            },
            "l3_ml_model_confidence": {
                "kpi_id": None,
                "description": "L3 - Confidence of the cryptomining detector in the security status in the last time interval of the service {service_id}",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_ML_CONFIDENCE,
                "service_ids": [],
            },
            "l3_unique_attack_conns": {
                "kpi_id": None,
                "description": "L3 - Number of attack connections detected in a time interval of the service {service_id} (attacks of the same connection [origin IP, origin port, destination IP and destination port] are only considered once)",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_ATTACK_CONNS,
                "service_ids": [],
            },
            "l3_unique_compromised_clients": {
                "kpi_id": None,
                "description": "L3 - Number of unique compromised clients of the service in a time interval of the service {service_id} (attacks from the same origin IP are only considered once)",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_COMPROMISED_CLIENTS,
                "service_ids": [],
            },
            "l3_unique_attackers": {
                "kpi_id": None,
                "description": "L3 - number of unique attackers of the service in a time interval of the service {service_id} (attacks from the same destination IP are only considered once)",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_ATTACKERS,
                "service_ids": [],
            },
        }
        self.attackmitigator_client = l3_attackmitigatorClient()

        # Environment variables. os.getenv returns a string whenever the variable is set, so
        # the values are parsed explicitly; otherwise the numeric comparison in make_inference
        # and the timedelta in monitor_kpis would raise TypeError.
        self.CLASSIFICATION_THRESHOLD = float(os.getenv("CAD_CLASSIFICATION_THRESHOLD", "0.5"))
        self.MONITORED_KPIS_TIME_INTERVAL_AGG = float(os.getenv("MONITORED_KPIS_TIME_INTERVAL_AGG", "5"))
        if self.MONITORED_KPIS_TIME_INTERVAL_AGG <= 0:
            raise ValueError("MONITORED_KPIS_TIME_INTERVAL_AGG must be a positive number of seconds")

        # Constants
        self.NORMAL_CLASS = 0
        self.CRYPTO_CLASS = 1

        self.kpi_test = None
        # Bounds [start, end) of the current KPI aggregation interval (set by monitor_kpis)
        self.time_interval_start = None
        self.time_interval_end = None

        # CAD evaluation tests
        self.cad_inference_times = []
        self.cad_num_inference_measurements = 100

        # AM evaluation tests
        self.am_notification_times = []

    @staticmethod
    def _dump_timing_stats(samples, subject, std_label, samples_path, stats_path):
        """Log summary statistics of timing samples and store them on disk.

        -input:
            + samples: list of measured durations
            + subject: word used in the messages ("inference", "notification")
            + std_label: label used for the standard deviation line
            + samples_path: file where the raw samples are saved with np.save
            + samples_path/stats_path are overwritten on every call
            + stats_path: text file where the summary is written
        -output: None
        """
        samples_np = np.array(samples)
        summary = [
            ("Average", np.mean(samples_np)),
            ("Max", np.max(samples_np)),
            ("Min", np.min(samples_np)),
            (std_label, np.std(samples_np)),
            ("Median", np.median(samples_np)),
        ]
        lines = ["{} {} time: {}".format(label, subject, value) for label, value in summary]

        for line in lines:
            LOGGER.debug(line)

        # The measurements are evaluation instrumentation only: a failure to store them must
        # not break attack detection, so it is logged instead of propagated.
        try:
            np.save(samples_path, samples_np)
            with open(stats_path, "w") as f:
                f.writelines(line + "\n" for line in lines)
        except OSError as e:
            LOGGER.warning("Could not store %s time measurements: %s", subject, e)

    def create_kpi(
        self,
        service_id,
        kpi_name,
        kpi_description,
        kpi_sample_type,
    ):
        """Create a monitored KPI for a specific service and add it to the Monitoring Client.

        -input:
            + service_id: service ID where the KPI will be monitored
            + kpi_name: name of the KPI (only used for logging)
            + kpi_description: description of the KPI
            + kpi_sample_type: KPI sample type of the KPI (it must be defined in the kpi_sample_types.proto file)
        -output: value returned by MonitoringClient.SetKpi, representing the created KPI
        """
        kpidescriptor = KpiDescriptor()
        kpidescriptor.kpi_description = kpi_description
        kpidescriptor.service_id.service_uuid.uuid = service_id.service_uuid.uuid
        kpidescriptor.kpi_sample_type = kpi_sample_type
        new_kpi = self.monitoring_client.SetKpi(kpidescriptor)

        LOGGER.info("Created KPI {}".format(kpi_name))

        return new_kpi

    def create_kpis(self, service_id, device_id, endpoint_id):
        """Create the monitored KPIs for a specific service and store their identifiers.

        The KPIs are added to the Monitoring Client and their identifiers are
        stored in the monitored_kpis dictionary.

        -input:
            + service_id: service ID where the KPIs will be monitored
            + device_id: currently unused
            + endpoint_id: currently unused
        -output: None
        """
        LOGGER.info("Creating KPIs for service {}".format(service_id))

        service_uuid = service_id.service_uuid.uuid

        # for now, all the KPIs are created for all the services from which requests are received
        for kpi_name, kpi_info in self.monitored_kpis.items():
            created_kpi = self.create_kpi(
                service_id,
                kpi_name,
                kpi_info["description"].format(service_id=service_uuid),
                kpi_info["kpi_sample_type"],
            )
            # NOTE(review): overwrites the identifier stored for any previous service
            kpi_info["kpi_id"] = created_kpi.kpi_id
            kpi_info["service_ids"].append(service_uuid)

        LOGGER.info("Created KPIs for service {}".format(service_id))

    def monitor_kpis(
        self,
    ):
        """Aggregate the inference results of the current time interval and report the KPIs.

        The aggregation interval lasts MONITORED_KPIS_TIME_INTERVAL_AGG seconds
        and starts at the timestamp of the first stored inference result.
        Results older than the start of the current interval are discarded from
        self.inference_results. KPIs are reported for every known service that
        has at least one result inside the current interval.

        -input: None
        -output: None
        """
        results = self.inference_results
        service_ids = self.service_ids

        LOGGER.debug("monitor_inference_results: {}".format(len(results)))
        LOGGER.debug("monitor_service_ids: {}".format(len(service_ids)))

        time_interval = self.MONITORED_KPIS_TIME_INTERVAL_AGG
        interval_length = timedelta(seconds=time_interval)

        # assign the timestamp of the first inference result to the time_interval_start
        if self.time_interval_start is None:
            if not results:
                LOGGER.debug("No inference results available. No KPIs will be reported.")
                return

            self.time_interval_start = results[0]["timestamp"]
            LOGGER.debug("self.time_interval_start: {}".format(self.time_interval_start))

            # add time_interval to the start to get the time interval end
            LOGGER.debug("time_interval: {}".format(time_interval))
            LOGGER.debug(interval_length)
            self.time_interval_end = self.time_interval_start + interval_length

        current_time = datetime.utcnow()
        LOGGER.debug("current_time: {}".format(current_time))

        if current_time >= self.time_interval_end:
            # Jump directly to the interval that contains the current time. Advancing a single
            # interval per call would leave the window lagging behind after an idle period, and
            # fresh results would be ignored until enough calls caught up.
            elapsed_intervals = (current_time - self.time_interval_start) // interval_length
            self.time_interval_start = self.time_interval_start + elapsed_intervals * interval_length
            self.time_interval_end = self.time_interval_start + interval_length

        LOGGER.debug("time_interval_start: {}".format(self.time_interval_start))
        LOGGER.debug("time_interval_end: {}".format(self.time_interval_end))

        # delete all inference results that are older than the time_interval_start; the stored
        # list is updated in place so that it does not grow without bound
        recent_results = [result for result in results if result["timestamp"] >= self.time_interval_start]
        num_deleted = len(results) - len(recent_results)
        if num_deleted > 0:
            results[:] = recent_results
            LOGGER.debug(f"Cleaned inference results. {num_deleted} inference results deleted")

        # check if there is at least one inference result in the current time_interval
        num_inference_results_in_time_interval = sum(
            1 for result in results if self.time_interval_start <= result["timestamp"] < self.time_interval_end
        )

        if num_inference_results_in_time_interval == 0:
            LOGGER.debug("Current time interval is empty. No KPIs will be reported.")
            LOGGER.debug("No KPIs sent to monitoring server")
            return

        LOGGER.debug(
            f"Current time interval is not empty (there are {num_inference_results_in_time_interval} inference results"
        )

        for service_id in service_ids:
            self._report_service_kpis(service_id, results)

    def _report_service_kpis(self, service_id, results):
        """Compute the five L3 KPIs of one service for the current interval and send them.

        -input:
            + service_id: service whose KPIs are reported
            + results: inference results to aggregate (entries of self.inference_results)
        -output: None
        """
        LOGGER.debug("service_id: {}".format(service_id))

        service_uuid = service_id.service_uuid.uuid
        service_outputs = [
            result["output"]
            for result in results
            if self.time_interval_start <= result["timestamp"] < self.time_interval_end
            and result["output"]["service_id"] == service_id
        ]

        # Without results there is nothing to aggregate: the mean confidence would be NaN
        if not service_outputs:
            LOGGER.debug("No inference results for this service in the current time interval. No KPIs sent.")
            return

        def outputs_for(kpi_name):
            # A KPI only aggregates results of services for which it has been created
            if service_uuid in self.monitored_kpis[kpi_name]["service_ids"]:
                return service_outputs
            return []

        def new_kpi(kpi_name):
            kpi = Kpi()
            kpi.kpi_id.kpi_id.CopyFrom(self.monitored_kpis[kpi_name]["kpi_id"])
            return kpi

        def mean_or_zero(values):
            return float(np.mean(values)) if values else 0.0

        # L3 security status: 1 if any result of the interval is not tagged as normal
        kpi_security_status = new_kpi("l3_security_status")
        outputs_last_time_interval = [output["tag"] for output in outputs_for("l3_security_status")]
        LOGGER.debug("outputs_last_time_interval: {}".format(outputs_last_time_interval))
        attack_detected = any(tag != self.NORMAL_CLASS for tag in outputs_last_time_interval)
        kpi_security_status.kpi_value.int32Val = 1 if attack_detected else 0

        # L3 ML model confidence: mean confidence of the results of the reported class
        kpi_conf = new_kpi("l3_ml_model_confidence")
        confidences_normal_last_time_interval = []
        confidences_crypto_last_time_interval = []
        for output in outputs_for("l3_ml_model_confidence"):
            if output["tag"] == self.NORMAL_CLASS:
                confidences_normal_last_time_interval.append(output["confidence"])
            elif output["tag"] == self.CRYPTO_CLASS:
                confidences_crypto_last_time_interval.append(output["confidence"])
            else:
                LOGGER.debug("Unknown tag: {}".format(output["tag"]))

        LOGGER.debug("confidences_normal_last_time_interval: {}".format(confidences_normal_last_time_interval))
        LOGGER.debug("confidences_crypto_last_time_interval: {}".format(confidences_crypto_last_time_interval))

        if attack_detected:
            kpi_conf.kpi_value.floatVal = mean_or_zero(confidences_crypto_last_time_interval)
        else:
            kpi_conf.kpi_value.floatVal = mean_or_zero(confidences_normal_last_time_interval)

        # L3 unique attack connections (grouping by origin IP, origin port, destination IP, destination port)
        kpi_unique_attack_conns = new_kpi("l3_unique_attack_conns")
        unique_attack_conns = {
            (output["ip_o"], output["port_o"], output["ip_d"], output["port_d"])
            for output in outputs_for("l3_unique_attack_conns")
            if output["tag"] == self.CRYPTO_CLASS
        }
        kpi_unique_attack_conns.kpi_value.int32Val = len(unique_attack_conns)

        # L3 unique compromised clients (grouping by origin IP)
        kpi_unique_compromised_clients = new_kpi("l3_unique_compromised_clients")
        unique_compromised_clients = {
            output["ip_o"]
            for output in outputs_for("l3_unique_compromised_clients")
            if output["tag"] == self.CRYPTO_CLASS
        }
        kpi_unique_compromised_clients.kpi_value.int32Val = len(unique_compromised_clients)

        # L3 unique attackers (grouping by destination IP)
        kpi_unique_attackers = new_kpi("l3_unique_attackers")
        unique_attackers = {
            output["ip_d"] for output in outputs_for("l3_unique_attackers") if output["tag"] == self.CRYPTO_CLASS
        }
        kpi_unique_attackers.kpi_value.int32Val = len(unique_attackers)

        named_kpis = [
            ("kpi_security_status", kpi_security_status),
            ("kpi_conf", kpi_conf),
            ("kpi_unique_attack_conns", kpi_unique_attack_conns),
            ("kpi_unique_compromised_clients", kpi_unique_compromised_clients),
            ("kpi_unique_attackers", kpi_unique_attackers),
        ]

        # All the KPIs of one report share the same timestamp
        timestamp = Timestamp()
        timestamp.timestamp = timestamp_utcnow_to_float()
        for _, kpi in named_kpis:
            kpi.timestamp.CopyFrom(timestamp)

        LOGGER.debug("Sending KPIs to monitoring server")
        for name, kpi in named_kpis:
            LOGGER.debug("{}: {}".format(name, kpi))

        # Reporting is best effort: a monitoring failure must not interrupt attack detection
        try:
            for _, kpi in named_kpis:
                self.monitoring_client.IncludeKpi(kpi)
        except Exception as e:
            LOGGER.error("Error sending KPIs to monitoring server: %s", e)
        else:
            LOGGER.debug("KPIs sent to monitoring server")

    def make_inference(self, request):
        """Classify connection as standard traffic or cryptomining attack and return results.

        -input:
            + request: L3CentralizedattackdetectorMetrics object with connection features information
        -output: dictionary with the assigned class, the prediction confidence and the connection
                 metadata; its keys are used by SendInput to build an L3AttackmitigatorOutput
        """
        x_data = np.array([[feature.feature for feature in request.features]])

        # Print input data shape
        LOGGER.debug("x_data.shape: {}".format(x_data.shape))

        # Get batch size
        batch_size = x_data.shape[0]

        # Print batch size
        LOGGER.debug("batch_size: {}".format(batch_size))

        # TEST: Remove later
        # NOTE(review): the single sample is duplicated to measure batched inference time; only
        # the first prediction is used below.
        test_batch_size = 1024

        # duplicate x_data to test_batch_size
        x_data = np.repeat(x_data, test_batch_size, axis=0)

        LOGGER.debug("x_data.shape: {}".format(x_data.shape))

        inference_time_start = time.perf_counter()

        # Perform inference
        predictions = self.model.run([self.prob_name], {self.input_name: x_data.astype(np.float32)})[0]

        inference_time_end = time.perf_counter()

        # Measure inference time
        inference_time = inference_time_end - inference_time_start
        self.cad_inference_times.append(inference_time)

        if len(self.cad_inference_times) > self.cad_num_inference_measurements:
            # NOTE(review): the samples file is named after test_batch_size while the stats file
            # is named after the batch size measured before the duplication - confirm intent.
            self._dump_timing_stats(
                self.cad_inference_times,
                "inference",
                "Standard deviation",
                f"inference_times_{test_batch_size}.npy",
                f"inference_times_stats_{batch_size}.txt",
            )

        # Gather the predicted class, the probability of that class and other relevant information required to block the attack
        output_message = {
            "confidence": None,
            "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
            "ip_o": request.connection_metadata.ip_o,
            "ip_d": request.connection_metadata.ip_d,
            "tag_name": None,
            "tag": None,
            "flow_id": request.connection_metadata.flow_id,
            "protocol": request.connection_metadata.protocol,
            "port_o": request.connection_metadata.port_o,
            "port_d": request.connection_metadata.port_d,
            "ml_id": "RandomForest",
            "service_id": request.connection_metadata.service_id,
            "endpoint_id": request.connection_metadata.endpoint_id,
            "time_start": request.connection_metadata.time_start,
            "time_end": request.connection_metadata.time_end,
        }

        # predictions[0] holds the per-class probabilities of the (first) sample
        if predictions[0][1] >= self.CLASSIFICATION_THRESHOLD:
            output_message["confidence"] = predictions[0][1]
            output_message["tag_name"] = "Crypto"
            output_message["tag"] = self.CRYPTO_CLASS
        else:
            output_message["confidence"] = predictions[0][0]
            output_message["tag_name"] = "Normal"
            output_message["tag"] = self.NORMAL_CLASS

        return output_message

    def SendInput(self, request, context):
        """Receive connection features, predict attack and communicate with Attack Mitigator.

        -input:
            + request: L3CentralizedattackdetectorMetrics object with connection features information
            + context: gRPC context (unused)
        -output: Empty object with a message about the execution of the function
        """
        # Perform inference with the data sent in the request
        LOGGER.info("Performing inference...")
        cryptomining_detector_output = self.make_inference(request)
        LOGGER.info("Inference performed correctly")

        # monitor_kpis compares these timestamps against datetime.utcnow(), so the same clock
        # must be used here; local time would misalign the aggregation interval.
        self.inference_results.append({"output": cryptomining_detector_output, "timestamp": datetime.utcnow()})

        service_id = request.connection_metadata.service_id
        device_id = request.connection_metadata.endpoint_id.device_id
        endpoint_id = request.connection_metadata.endpoint_id

        # Check if a request of a new service has been received and, if so, create the monitored KPIs for that service
        if service_id not in self.service_ids:
            self.create_kpis(service_id, device_id, endpoint_id)
            self.service_ids.append(service_id)

        self.monitor_kpis()

        LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output))

        # Only notify Attack Mitigator when a cryptomining connection has been detected
        if cryptomining_detector_output["tag_name"] != "Crypto":
            LOGGER.info("No attack detected")
            return Empty(message="Ok, information received (no attack detected)")

        notification_time_start = time.perf_counter()

        LOGGER.info("Crypto attack detected")

        # Notify the Attack Mitigator component about the attack
        LOGGER.info("Notifying the Attack Mitigator component about the attack in order to block the connection...")

        try:
            LOGGER.info("Sending the connection information to the Attack Mitigator component...")
            message = L3AttackmitigatorOutput(**cryptomining_detector_output)
            self.attackmitigator_client.SendOutput(message)
        except Exception as e:
            LOGGER.error("Error notifying the Attack Mitigator component about the attack: %s", e)
            LOGGER.error("Couldn't find l3_attackmitigator")

            return Empty(message="Attack Mitigator not found")

        notification_time_end = time.perf_counter()

        self.am_notification_times.append(notification_time_end - notification_time_start)

        LOGGER.debug(f"am_notification_times length: {len(self.am_notification_times)}")

        if len(self.am_notification_times) > 100:
            self._dump_timing_stats(
                self.am_notification_times,
                "notification",
                "Std",
                "am_notification_times.npy",
                "am_notification_times_stats.txt",
            )

        LOGGER.info("Attack Mitigator notified")

        return Empty(message="OK, information received and mitigator notified abou the attack")