diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py index 2de768810c06f48e2ffa282dd4e1308dc30554b0..c56f117f383af69d90189653d5e45344d58aecfd 100644 --- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py @@ -20,6 +20,7 @@ import grpc import numpy as np import onnxruntime as rt import logging +from time import sleep from common.proto.l3_centralizedattackdetector_pb2 import Empty from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorServicer @@ -38,74 +39,301 @@ from common.proto.context_pb2 import Timestamp from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigatorClient +from multiprocessing import Process, Queue + LOGGER = logging.getLogger(__name__) current_dir = os.path.dirname(os.path.abspath(__file__)) MODEL_FILE = os.path.join(current_dir, "ml_model/crypto_5g_rf_spider_features.onnx") -classification_threshold = os.getenv("CAD_CLASSIFICATION_THRESHOLD", 0.5) - class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer): """ Initialize variables, prediction model and clients of components used by CAD - """ + """ + def __init__(self): LOGGER.info("Creating Centralized Attack Detector Service") - self.inference_values = [] + self.inference_values = Queue() + self.inference_results = Queue() self.model = rt.InferenceSession(MODEL_FILE) self.input_name = self.model.get_inputs()[0].name self.label_name = self.model.get_outputs()[0].name self.prob_name = self.model.get_outputs()[1].name self.monitoring_client = MonitoringClient() - self.predicted_class_kpi_id = None - self.class_probability_kpi_id = None - + self.service_ids = [] + self.monitored_kpis = { + "l3_security_status": { + "kpi_id": None, + "description": "L3 - Confidence of the cryptomining detector in the security status in the last time interval of the service {service_id}", + "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_UNKNOWN, # TODO: change this to KPI_L3_SECURITY_STATUS and add it to kpi_sample_types.proto + "service_ids": [], + }, + "l3_ml_model_confidence": { + "kpi_id": None, + "description": "L3 - Security status of the service in a time interval of the service {service_id} (“0†if no attack has been detected on the service and “1†if a cryptomining attack has been detected)", + "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_UNKNOWN, # TODO: change this to KPI_L3_ML_CONFIDENCE and add it to kpi_sample_types.proto + "service_ids": [], + }, + "l3_unique_attack_conns": { + "kpi_id": None, + "description": "L3 - Number of attack connections detected in a time interval of the service {service_id} (attacks of the same connection [origin IP, origin port, destination IP and destination port] are only considered once)", + "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_UNKNOWN, # TODO: change this to KPI_UNIQUE_ATTACK_CONNS and add it to kpi_sample_types.proto + "service_ids": [], + }, + "l3_unique_compromised_clients": { + "kpi_id": None, + "description": "L3 - Number of unique compromised clients of the service in a time interval of the service {service_id} (attacks from the same origin IP are only considered once)", + "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_UNKNOWN, # TODO: change this to KPI_UNIQUE_COMPROMISED_CLIENTS and add it to kpi_sample_types.proto + "service_ids": [], + }, + "l3_unique_attackers": { + "kpi_id": None, + "description": "L3 - number of unique attackers of the service in a time interval of the service {service_id} (attacks from the same destination IP are only considered once)", + "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_UNKNOWN, # TODO: change this to KPI_UNIQUE_ATTACKERS and add it to kpi_sample_types.proto + "service_ids": [], + }, + } self.attackmitigator_client = l3_attackmitigatorClient() - """ - Create the Cryptomining Detector Predicted Class KPI for a service and add it to the Monitoring Client - -input: - + client: Monitoring Client object where the KPI will be tracked - + service_id: service ID where the KPI will be created - -output: KPI identifier representing the Cryptomining Detector Predicted Class KPI - """ - def create_predicted_class_kpi(self, client: MonitoringClient, service_id): - kpi_description: KpiDescriptor = KpiDescriptor() - kpi_description.kpi_description = "Cryptomining Detector Predicted Class (service: {})".format(service_id) - kpi_description.service_id.service_uuid.uuid = service_id.service_uuid.uuid - kpi_description.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN - new_kpi = client.SetKpi(kpi_description) + # Environment variables + self.CLASSIFICATION_THRESHOLD = os.getenv("CAD_CLASSIFICATION_THRESHOLD", 0.5) + self.MONITORED_KPIS_TIME_INTERVAL_AGG = os.getenv("MONITORED_KPIS_TIME_INTERVAL_AGG", 5) - LOGGER.info("Created Predicted Class KPI {}...".format(new_kpi.kpi_id)) + # Constants + self.NORMAL_CLASS = 0 + self.CRYPTO_CLASS = 1 - return new_kpi + # start monitoring process + self.monitoring_process = Process( + target=self.monitoring_process, args=(self.inference_values, self.inference_results) + ) + self.monitoring_process.start() """ - Create the Cryptomining Detector Prediction KPI for a service and add it to the Monitoring Client + Create a monitored KPI for a specific service and add it to the Monitoring Client -input: + client: Monitoring Client object where the KPI will be tracked - + service_id: service ID where the KPI will be created - -output: KPI identifier representing the Cryptomining Detector Prediction KPI + + service_id: service ID where the KPI will be monitored + -output: KPI identifier representing the KPI """ - def create_class_prob_kpi(self, client: MonitoringClient, service_id): + + def create_kpi(self, client: MonitoringClient, service_id, kpi_description, kpi_sample_type): kpi_description: KpiDescriptor = KpiDescriptor() - kpi_description.kpi_description = "Cryptomining Detector Prediction (service: {})".format(service_id) + kpi_description.kpi_description = kpi_description kpi_description.service_id.service_uuid.uuid = service_id.service_uuid.uuid - kpi_description.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN + kpi_description.kpi_sample_type = kpi_sample_type new_kpi = client.SetKpi(kpi_description) - LOGGER.info("Created Class Probability KPI {}...".format(new_kpi.kpi_id)) + LOGGER.info("Created KPI {}...".format(kpi_sample_type)) return new_kpi + """ + Create the monitored KPIs for a specific service, add them to the Monitoring Client and store their identifiers in the monitored_kpis dictionary + -input: + + service_id: service ID where the KPIs will be monitored + -output: None + """ + + def create_kpis(self, service_id): + # for now, all the KPIs are created for all the services from which requests are received + for kpi in self.monitored_kpis: + created_kpi = self.create_kpi( + self.monitoring_client, + service_id, + self.monitored_kpis[kpi]["description"], + self.monitored_kpis[kpi]["kpi_sample_type"], + ) + self.monitored_kpis[kpi]["kpi_id"] = created_kpi.kpi_id + self.monitored_kpis[kpi]["service_ids"].append(service_id.service_uuid.uuid) + + def monitor_kpis(self): + while True: + # get all information from the inference_values queue + monitor_inference_values = [] + + for i in range(self.inference_values.qsize()): + monitor_inference_values.append(self.inference_values.get()) + + # get all information from the inference_results queue + monitor_inference_results = [] + + for i in range(self.inference_results.qsize()): + monitor_inference_results.append(self.inference_results.get()) + + for service_id in self.service_ids: + time_interval = self.MONITORED_KPIS_TIME_INTERVAL_AGG + time_interval_start = datetime.utcnow() + time_interval_end = time_interval_start + time_interval + + # L3 security status + kpi_security_status = Kpi() + kpi_security_status.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_security_status"]["kpi_id"].kpi_id) + + # get the output.tag of the ML model of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable + outputs_last_time_interval = [] + + for i in range(self.monitor_inference_results): + if ( + self.monitor_inference_results[i]["timestamp"] >= time_interval_start + and self.monitor_inference_results[i]["timestamp"] <= time_interval_end + and self.monitor_inference_results[i]["service_id"] == service_id + and service_id in self.monitored_kpis["l3_security_status"]["service_ids"] + ): + outputs_last_time_interval.append(self.monitor_inference_results[i]["output"]["tag"]) + + kpi_security_status.kpi_value.intVal = ( + 0 if np.all(outputs_last_time_interval == self.NORMAL_CLASS) else 1 + ) + + # L3 ML model confidence + kpi_conf = Kpi() + kpi_conf.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_ml_model_confidence"]["kpi_id"].kpi_id) + + # get the output.confidence of the ML model of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable + confidences_normal_last_time_interval = [] + confidences_crypto_last_time_interval = [] + + for i in range(self.monitor_inference_results): + if ( + self.monitor_inference_results[i]["timestamp"] >= time_interval_start + and self.monitor_inference_results[i]["timestamp"] <= time_interval_end + and self.monitor_inference_results[i]["service_id"] == service_id + and service_id in self.monitored_kpis["l3_security_status"]["service_ids"] + ): + if self.monitor_inference_results[i]["output"]["tag"] == self.NORMAL_CLASS: + confidences_normal_last_time_interval.append( + self.monitor_inference_results[i]["output"]["confidence"] + ) + elif self.monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS: + confidences_crypto_last_time_interval.append( + self.monitor_inference_results[i]["output"]["confidence"] + ) + + kpi_conf.kpi_value.intVal = ( + np.mean(confidences_crypto_last_time_interval) + if np.all(outputs_last_time_interval == self.CRYPTO_CLASS) + else np.mean(confidences_normal_last_time_interval) + ) + + # L3 unique attack connections + kpi_unique_attack_conns = Kpi() + kpi_unique_attack_conns.kpi_id.kpi_id.CopyFrom( + self.monitored_kpis["l3_unique_attack_conns"]["kpi_id"].kpi_id + ) + + # get the number of unique attack connections (grouping by origin IP, origin port, destination IP, destination port) of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable + num_unique_attack_conns_last_time_interval = 0 + unique_attack_conns_last_time_interval = [] + + for i in range(self.monitor_inference_results): + if ( + self.monitor_inference_results[i]["timestamp"] >= time_interval_start + and self.monitor_inference_results[i]["timestamp"] <= time_interval_end + and self.monitor_inference_results[i]["service_id"] == service_id + and service_id in self.monitored_kpis["l3_security_status"]["service_ids"] + ): + if self.monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS: + current_attack_conn = { + "ip_o": self.monitor_inference_results[i]["input"]["src_ip"], + "port_o": self.monitor_inference_results[i]["input"]["src_port"], + "ip_d": self.monitor_inference_results[i]["input"]["dst_ip"], + "port_d": self.monitor_inference_results[i]["input"]["dst_port"], + } + + for j in range(unique_attack_conns_last_time_interval): + if current_attack_conn == unique_attack_conns_last_time_interval[j]: + break + + num_unique_attack_conns_last_time_interval += 1 + unique_attack_conns_last_time_interval.append(current_attack_conn) + + kpi_unique_attack_conns.kpi_value.intVal = num_unique_attack_conns_last_time_interval + + # L3 unique compromised clients + kpi_unique_compromised_clients = Kpi() + kpi_unique_compromised_clients.kpi_id.kpi_id.CopyFrom( + self.monitored_kpis["l3_unique_attack_conns"]["kpi_id"].kpi_id + ) + + # get the number of unique compromised clients (grouping by origin IP) of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable + num_unique_compromised_clients_last_time_interval = 0 + unique_compromised_clients_last_time_interval = [] + + for i in range(self.monitor_inference_results): + if ( + self.monitor_inference_results[i]["timestamp"] >= time_interval_start + and self.monitor_inference_results[i]["timestamp"] <= time_interval_end + and self.monitor_inference_results[i]["service_id"] == service_id + and service_id in self.monitored_kpis["l3_security_status"]["service_ids"] + ): + if self.monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS: + if ( + self.monitor_inference_results[i]["output"]["ip_o"] + not in unique_compromised_clients_last_time_interval + ): + unique_compromised_clients_last_time_interval.append( + self.monitor_inference_results[i]["output"]["ip_o"] + ) + num_unique_compromised_clients_last_time_interval += 1 + + kpi_unique_compromised_clients.kpi_value.intVal = num_unique_compromised_clients_last_time_interval + + # L3 unique attackers + kpi_unique_attackers = Kpi() + kpi_unique_attackers.kpi_id.kpi_id.CopyFrom( + self.monitored_kpis["l3_unique_attack_conns"]["kpi_id"].kpi_id + ) + + # get the number of unique attackers (grouping by destination ip) of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable + num_unique_attackers_last_time_interval = 0 + unique_attackers_last_time_interval = [] + + for i in range(self.monitor_inference_results): + if ( + self.monitor_inference_results[i]["timestamp"] >= time_interval_start + and self.monitor_inference_results[i]["timestamp"] <= time_interval_end + and self.monitor_inference_results[i]["service_id"] == service_id + and service_id in self.monitored_kpis["l3_security_status"]["service_ids"] + ): + if self.monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS: + if ( + self.monitor_inference_results[i]["output"]["ip_d"] + not in unique_attackers_last_time_interval + ): + unique_attackers_last_time_interval.append( + self.monitor_inference_results[i]["output"]["ip_d"] + ) + num_unique_attackers_last_time_interval += 1 + + kpi_unique_attackers.kpi_value.intVal = num_unique_attackers_last_time_interval + + timestamp = Timestamp() + timestamp.timestamp = timestamp_utcnow_to_float() + + kpi_security_status.timestamp.CopyFrom(timestamp) + kpi_conf.timestamp.CopyFrom(timestamp) + kpi_unique_attack_conns.timestamp.CopyFrom(timestamp) + kpi_unique_compromised_clients.timestamp.CopyFrom(timestamp) + kpi_unique_attackers.timestamp.CopyFrom(timestamp) + + self.monitoring_client.IncludeKpi(kpi_security_status) + self.monitoring_client.IncludeKpi(kpi_conf) + self.monitoring_client.IncludeKpi(kpi_unique_attack_conns) + self.monitoring_client.IncludeKpi(kpi_unique_compromised_clients) + self.monitoring_client.IncludeKpi(kpi_unique_attackers) + + sleep(self.MONITORED_KPIS_TIME_INTERVAL_AGG) + """ Classify connection as standard traffic or cryptomining attack and return results -input: + request: L3CentralizedattackdetectorMetrics object with connection features information -output: L3AttackmitigatorOutput object with information about the assigned class and prediction confidence """ + def make_inference(self, request): x_data = np.array( [ @@ -145,16 +373,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto "time_end": request.time_end, } - if predictions[0][1] >= classification_threshold: + if predictions[0][1] >= self.CLASSIFICATION_THRESHOLD: output_message["confidence"] = predictions[0][1] output_message["tag_name"] = "Crypto" - output_message["tag"] = 1 + output_message["tag"] = self.CRYPTO_CLASS else: output_message["confidence"] = predictions[0][0] output_message["tag_name"] = "Normal" - output_message["tag"] = 0 + output_message["tag"] = self.NORMAL_CLASS - return L3AttackmitigatorOutput(**output_message) + return output_message """ Receive features from Attack Mitigator, predict attack and communicate with Attack Mitigator @@ -162,50 +390,28 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto + request: L3CentralizedattackdetectorMetrics object with connection features information -output: Empty object with a message about the execution of the function """ + def SendInput(self, request, context): # Store the data sent in the request - self.inference_values.append(request) + self.inference_values.put({"request": request, "timestamp": datetime.now()}) # Perform inference with the data sent in the request logging.info("Performing inference...") - output = self.make_inference(request) + cryptomining_detector_output = self.make_inference(request) logging.info("Inference performed correctly") - # Include monitored KPIs values - service_id = request.service_id + # Store the results of the inference that will be later used to monitor the KPIs + self.inference_results.put({"output": cryptomining_detector_output, "timestamp": datetime.now()}) - if self.predicted_class_kpi_id is None: - self.predicted_class_kpi_id = self.create_predicted_class_kpi(self.monitoring_client, service_id) - - if self.class_probability_kpi_id is None: - self.class_probability_kpi_id = self.create_class_prob_kpi(self.monitoring_client, service_id) - - # Packet Aggregation Features -> DAD -> CAD -> ML -> (2 Instantaneous Value: higher class probability, predicted class) -> Monitoring - # In addition, two counters: - # Counter 1: Total number of crypto attack connections - # Counter 2: Rate of crypto attack connections with respect to the total number of connections - - # Predicted class KPI - kpi_class = Kpi() - kpi_class.kpi_id.kpi_id.CopyFrom(self.predicted_class_kpi_id.kpi_id) - kpi_class.kpi_value.int32Val = 1 if output.tag_name == "Crypto" else 0 - - # Class probability KPI - kpi_prob = Kpi() - kpi_prob.kpi_id.kpi_id.CopyFrom(self.class_probability_kpi_id.kpi_id) - kpi_prob.kpi_value.floatVal = output.confidence - - timestamp = Timestamp() - timestamp.timestamp = timestamp_utcnow_to_float() + service_id = request.service_id - kpi_class.timestamp.CopyFrom(timestamp) - kpi_prob.timestamp.CopyFrom(kpi_class.timestamp) + # Check if a request of a new service has been received and, if so, create the monitored KPIs for that service + if service_id not in self.service_ids: + self.create_kpis(service_id) + self.service_ids.append(service_id) - self.monitoring_client.IncludeKpi(kpi_class) - self.monitoring_client.IncludeKpi(kpi_prob) - # Only notify Attack Mitigator when a cryptomining connection has been detected - if output.tag_name == "Crypto": + if cryptomining_detector_output["tag_name"] == "Crypto": logging.info("Crypto attack detected") # Notify the Attack Mitigator component about the attack @@ -215,7 +421,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto try: logging.info("Sending the connection information to the Attack Mitigator component...") - response = self.attackmitigator_client.SendOutput(output) + message = L3AttackmitigatorOutput(**cryptomining_detector_output) + response = self.attackmitigator_client.SendOutput(message) logging.info( "Attack Mitigator notified and received response: ", response.message ) # FIX No message received @@ -231,6 +438,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto return Empty(message="Ok, information received (no attack detected)") + """ def GetOutput(self, request, context): logging.info("Returning inference output...")