diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py index d84e71ce63dee0e61109487b393605b24ad0f32c..26874a3264e995d71a6caa9c0a6cad00059597ed 100644 --- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py @@ -64,8 +64,6 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto def __init__(self): LOGGER.info("Creating Centralized Attack Detector Service") - # self.inference_values = Queue() - # self.inference_results = Queue() self.inference_values = [] self.inference_results = [] self.model = rt.InferenceSession(MODEL_FILE) @@ -74,7 +72,6 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto self.prob_name = self.model.get_outputs()[1].name self.monitoring_client = MonitoringClient() self.service_ids = [] - # self.monitored_service_ids = Queue() self.monitored_kpis = { "l3_security_status": { "kpi_id": None, @@ -108,8 +105,6 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto }, } self.attackmitigator_client = l3_attackmitigatorClient() - # self.context_client = ContextClient() - # self.context_id = "admin" # Environment variables self.CLASSIFICATION_THRESHOLD = os.getenv("CAD_CLASSIFICATION_THRESHOLD", 0.5) @@ -119,16 +114,6 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto self.NORMAL_CLASS = 0 self.CRYPTO_CLASS = 1 - # # start monitoring process - # self.monitoring_process = Process( - # target=self.monitor_kpis, - # args=( - # self.monitored_service_ids, - # self.inference_results, - # ), - # ) - # self.monitoring_process.start() - self.kpi_test = None self.time_interval_start = None self.time_interval_end = None @@ -157,10 +142,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto kpidescriptor = KpiDescriptor() kpidescriptor.kpi_description = kpi_description kpidescriptor.service_id.service_uuid.uuid = service_id.service_uuid.uuid - kpidescriptor.device_id.device_uuid.uuid = device_id.device_uuid.uuid - kpidescriptor.endpoint_id.endpoint_uuid.uuid = endpoint_id.endpoint_uuid.uuid - kpidescriptor.slice_id.slice_uuid.uuid = slice_id.slice_uuid.uuid - kpidescriptor.connection_id.connection_uuid.uuid = connection_id.connection_uuid.uuid + # kpidescriptor.device_id.device_uuid.uuid = device_id.device_uuid.uuid + # kpidescriptor.endpoint_id.endpoint_uuid.uuid = endpoint_id.endpoint_uuid.uuid + # kpidescriptor.slice_id.slice_uuid.uuid = slice_id.slice_uuid.uuid + # kpidescriptor.connection_id.connection_uuid.uuid = connection_id.connection_uuid.uuid kpidescriptor.kpi_sample_type = kpi_sample_type new_kpi = self.monitoring_client.SetKpi(kpidescriptor) @@ -204,292 +189,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto LOGGER.info("Created KPIs for service {}".format(service_id)) - # LOGGER.info("Starting monitoring process") - # self.monitoring_process.start() - - def monitor_kpis(self, service_ids, inference_results): - self.monitoring_client_test = MonitoringClient() - - monitor_inference_results = [] - monitor_service_ids = [] - - # sleep(10) - time_interval_start = None - - while True: - # get all information from the inference_results queue - # deserialize the inference results - # for i in range(len(monitor_inference_results)): - # monitor_inference_results[i]["output"]["service_id"] = Parse( - # monitor_inference_results[i]["output"]["service_id"], ServiceId() - # ) - # monitor_inference_results[i]["output"]["endpoint_id"] = Parse( - # monitor_inference_results[i]["output"]["endpoint_id"], EndPointId() - # ) - - LOGGER.debug("Sleeping for %s seconds", self.MONITORED_KPIS_TIME_INTERVAL_AGG) - sleep(self.MONITORED_KPIS_TIME_INTERVAL_AGG) - - for i in range(service_ids.qsize()): - new_service_id = service_ids.get() - service_id = Parse(new_service_id, ServiceId()) - monitor_service_ids.append(service_id) - - for i in range(inference_results.qsize()): - new_inference_result = inference_results.get() - new_inference_result["output"]["service_id"] = Parse( - new_inference_result["output"]["service_id"], ServiceId() - ) - new_inference_result["output"]["endpoint_id"] = Parse( - new_inference_result["output"]["endpoint_id"], EndPointId() - ) - - monitor_inference_results.append(new_inference_result) - - LOGGER.debug("monitor_inference_results: {}".format(len(monitor_inference_results))) - LOGGER.debug("monitor_service_ids: {}".format(len(monitor_service_ids))) - - while len(monitor_inference_results) == 0: - LOGGER.debug("monitor_inference_results is empty, waiting for new inference results") - - for i in range(inference_results.qsize()): - new_inference_result = inference_results.get() - new_inference_result["output"]["service_id"] = Parse( - new_inference_result["output"]["service_id"], ServiceId() - ) - new_inference_result["output"]["endpoint_id"] = Parse( - new_inference_result["output"]["endpoint_id"], EndPointId() - ) - - monitor_inference_results.append(new_inference_result) - - sleep(1) - - for service_id in monitor_service_ids: - LOGGER.debug("service_id: {}".format(service_id)) - - time_interval = self.MONITORED_KPIS_TIME_INTERVAL_AGG - # time_interval_start = datetime.utcnow() - - # assign the timestamp of the first inference result to the time_interval_start - if time_interval_start is None: - time_interval_start = monitor_inference_results[0]["timestamp"] - else: - time_interval_start = time_interval_start + timedelta(seconds=time_interval) - - # add time_interval to the current time to get the time interval end - time_interval_end = time_interval_start + timedelta(seconds=time_interval) - - # delete the inference results that are previous to the time interval start - deleted_items = [] - - for i in range(len(monitor_inference_results)): - if monitor_inference_results[i]["timestamp"] < time_interval_start: - deleted_items.append(i) - - LOGGER.debug("deleted_items: {}".format(deleted_items)) - - for i in range(len(deleted_items)): - monitor_inference_results.pop(deleted_items[i] - i) - - if len(monitor_inference_results) == 0: - break - - LOGGER.debug("time_interval_start: {}".format(time_interval_start)) - LOGGER.debug("time_interval_end: {}".format(time_interval_end)) - - # L3 security status - kpi_security_status = Kpi() - kpi_security_status.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_security_status"]["kpi_id"]) - - # get the output.tag of the ML model of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable - outputs_last_time_interval = [] - - for i in range(len(monitor_inference_results)): - if ( - monitor_inference_results[i]["timestamp"] >= time_interval_start - and monitor_inference_results[i]["timestamp"] < time_interval_end - and monitor_inference_results[i]["output"]["service_id"] == service_id - and service_id.service_uuid.uuid in self.monitored_kpis["l3_security_status"]["service_ids"] - ): - outputs_last_time_interval.append(monitor_inference_results[i]["output"]["tag"]) - - kpi_security_status.kpi_value.int32Val = ( - 0 if np.all(outputs_last_time_interval == self.NORMAL_CLASS) else 1 - ) - - # L3 ML model confidence - kpi_conf = Kpi() - kpi_conf.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_ml_model_confidence"]["kpi_id"]) - - # get the output.confidence of the ML model of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable - confidences_normal_last_time_interval = [] - confidences_crypto_last_time_interval = [] - - for i in range(len(monitor_inference_results)): - LOGGER.debug("monitor_inference_results[i]: {}".format(monitor_inference_results[i])) - - if ( - monitor_inference_results[i]["timestamp"] >= time_interval_start - and monitor_inference_results[i]["timestamp"] < time_interval_end - and monitor_inference_results[i]["output"]["service_id"] == service_id - and service_id.service_uuid.uuid - in self.monitored_kpis["l3_ml_model_confidence"]["service_ids"] - ): - if monitor_inference_results[i]["output"]["tag"] == self.NORMAL_CLASS: - confidences_normal_last_time_interval.append( - monitor_inference_results[i]["output"]["confidence"] - ) - elif monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS: - confidences_crypto_last_time_interval.append( - monitor_inference_results[i]["output"]["confidence"] - ) - else: - LOGGER.debug("Unknown tag: {}".format(monitor_inference_results[i]["output"]["tag"])) - - LOGGER.debug("confidences_normal_last_time_interval: {}".format(confidences_normal_last_time_interval)) - LOGGER.debug("confidences_crypto_last_time_interval: {}".format(confidences_crypto_last_time_interval)) - - kpi_conf.kpi_value.floatVal = ( - np.mean(confidences_crypto_last_time_interval) - if np.all(outputs_last_time_interval == self.CRYPTO_CLASS) - else np.mean(confidences_normal_last_time_interval) - ) - - # L3 unique attack connections - kpi_unique_attack_conns = Kpi() - kpi_unique_attack_conns.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_unique_attack_conns"]["kpi_id"]) - - # get the number of unique attack connections (grouping by origin IP, origin port, destination IP, destination port) of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable - num_unique_attack_conns_last_time_interval = 0 - unique_attack_conns_last_time_interval = [] - - for i in range(len(monitor_inference_results)): - if ( - monitor_inference_results[i]["timestamp"] >= time_interval_start - and monitor_inference_results[i]["timestamp"] < time_interval_end - and monitor_inference_results[i]["output"]["service_id"] == service_id - and service_id.service_uuid.uuid - in self.monitored_kpis["l3_unique_attack_conns"]["service_ids"] - ): - if monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS: - current_attack_conn = { - "ip_o": monitor_inference_results[i]["output"]["ip_o"], - "port_o": monitor_inference_results[i]["output"]["port_o"], - "ip_d": monitor_inference_results[i]["output"]["ip_d"], - "port_d": monitor_inference_results[i]["output"]["port_d"], - } - - is_unique = True - - for j in range(len(unique_attack_conns_last_time_interval)): - if current_attack_conn == unique_attack_conns_last_time_interval[j]: - is_unique = False - - if is_unique: - num_unique_attack_conns_last_time_interval += 1 - unique_attack_conns_last_time_interval.append(current_attack_conn) - - kpi_unique_attack_conns.kpi_value.int32Val = num_unique_attack_conns_last_time_interval - - # L3 unique compromised clients - kpi_unique_compromised_clients = Kpi() - kpi_unique_compromised_clients.kpi_id.kpi_id.CopyFrom( - self.monitored_kpis["l3_unique_compromised_clients"]["kpi_id"] - ) - - # get the number of unique compromised clients (grouping by origin IP) of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable - num_unique_compromised_clients_last_time_interval = 0 - unique_compromised_clients_last_time_interval = [] - - for i in range(len(monitor_inference_results)): - if ( - monitor_inference_results[i]["timestamp"] >= time_interval_start - and monitor_inference_results[i]["timestamp"] < time_interval_end - and monitor_inference_results[i]["output"]["service_id"] == service_id - and service_id.service_uuid.uuid - in self.monitored_kpis["l3_unique_compromised_clients"]["service_ids"] - ): - if monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS: - if ( - monitor_inference_results[i]["output"]["ip_o"] - not in unique_compromised_clients_last_time_interval - ): - unique_compromised_clients_last_time_interval.append( - monitor_inference_results[i]["output"]["ip_o"] - ) - num_unique_compromised_clients_last_time_interval += 1 - - kpi_unique_compromised_clients.kpi_value.int32Val = num_unique_compromised_clients_last_time_interval - - # L3 unique attackers - kpi_unique_attackers = Kpi() - kpi_unique_attackers.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_unique_attackers"]["kpi_id"]) - - # get the number of unique attackers (grouping by destination ip) of the last aggregation time interval as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable - num_unique_attackers_last_time_interval = 0 - unique_attackers_last_time_interval = [] - - for i in range(len(monitor_inference_results)): - if ( - monitor_inference_results[i]["timestamp"] >= time_interval_start - and monitor_inference_results[i]["timestamp"] < time_interval_end - and monitor_inference_results[i]["output"]["service_id"] == service_id - and service_id.service_uuid.uuid in self.monitored_kpis["l3_unique_attackers"]["service_ids"] - ): - if monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS: - if ( - monitor_inference_results[i]["output"]["ip_d"] - not in unique_attackers_last_time_interval - ): - unique_attackers_last_time_interval.append( - monitor_inference_results[i]["output"]["ip_d"] - ) - num_unique_attackers_last_time_interval += 1 - - kpi_unique_attackers.kpi_value.int32Val = num_unique_attackers_last_time_interval - - timestamp = Timestamp() - timestamp.timestamp = timestamp_utcnow_to_float() - - kpi_security_status.timestamp.CopyFrom(timestamp) - kpi_conf.timestamp.CopyFrom(timestamp) - kpi_unique_attack_conns.timestamp.CopyFrom(timestamp) - kpi_unique_compromised_clients.timestamp.CopyFrom(timestamp) - kpi_unique_attackers.timestamp.CopyFrom(timestamp) - - LOGGER.debug("Sending KPIs to monitoring server") - - LOGGER.debug("kpi_security_status: {}".format(kpi_security_status)) - LOGGER.debug("kpi_conf: {}".format(kpi_conf)) - LOGGER.debug("kpi_unique_attack_conns: {}".format(kpi_unique_attack_conns)) - LOGGER.debug("kpi_unique_compromised_clients: {}".format(kpi_unique_compromised_clients)) - LOGGER.debug("kpi_unique_attackers: {}".format(kpi_unique_attackers)) - - kpi_security_status.kpi_value.floatVal = 500 - kpi_conf.kpi_value.floatVal = 500 - kpi_unique_attack_conns.kpi_value.floatVal = 500 - kpi_unique_compromised_clients.kpi_value.floatVal = 500 - kpi_unique_attackers.kpi_value.floatVal = 500 - - try: - self.monitoring_client_test.IncludeKpi(kpi_security_status) - self.monitoring_client_test.IncludeKpi(kpi_conf) - self.monitoring_client_test.IncludeKpi(kpi_unique_attack_conns) - self.monitoring_client_test.IncludeKpi(kpi_unique_compromised_clients) - self.monitoring_client_test.IncludeKpi(kpi_unique_attackers) - except Exception as e: - LOGGER.debug("Error sending KPIs to monitoring server: {}".format(e)) - - # self.monitoring_client_test.IncludeKpi(kpi_security_status) - # self.monitoring_client_test.IncludeKpi(kpi_conf) - # self.monitoring_client_test.IncludeKpi(kpi_unique_attack_conns) - # self.monitoring_client_test.IncludeKpi(kpi_unique_compromised_clients) - # self.monitoring_client_test.IncludeKpi(kpi_unique_attackers) - - LOGGER.debug("KPIs sent to monitoring server") - - def monitor_kpis_test( + def monitor_kpis( self, ): monitor_inference_results = self.inference_results @@ -522,8 +222,25 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto LOGGER.debug("time_interval_start: {}".format(self.time_interval_start)) LOGGER.debug("time_interval_end: {}".format(self.time_interval_end)) + # delete all inference results that are older than the time_interval_start + delete_inference_results = [] + + for i in range(len(monitor_inference_results)): + inference_result_timestamp = monitor_inference_results[i]["timestamp"] + + if inference_result_timestamp < self.time_interval_start: + delete_inference_results.append(monitor_inference_results[i]) + + if len(delete_inference_results) > 0: + monitor_inference_results = [ + inference_result + for inference_result in monitor_inference_results + if inference_result not in delete_inference_results + ] + LOGGER.debug(f"Cleaned inference results. {len(delete_inference_results)} inference results deleted") + # check if there is at least one inference result in monitor_inference_results in the current time_interval - non_empty_time_interval = False + num_inference_results_in_time_interval = 0 for i in range(len(monitor_inference_results)): inference_result_timestamp = monitor_inference_results[i]["timestamp"] @@ -532,8 +249,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto inference_result_timestamp >= self.time_interval_start and inference_result_timestamp < self.time_interval_end ): - non_empty_time_interval = True - break + num_inference_results_in_time_interval += 1 + + if num_inference_results_in_time_interval > 0: + non_empty_time_interval = True + LOGGER.debug( + f"Current time interval is not empty (there are {num_inference_results_in_time_interval} inference results" + ) + else: + non_empty_time_interval = False + LOGGER.debug("Current time interval is empty. No KPIs will be reported.") if non_empty_time_interval: for service_id in monitor_service_ids: @@ -725,15 +450,9 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto except Exception as e: LOGGER.debug("Error sending KPIs to monitoring server: {}".format(e)) - # self.monitoring_client_test.IncludeKpi(kpi_security_status) - # self.monitoring_client_test.IncludeKpi(kpi_conf) - # self.monitoring_client_test.IncludeKpi(kpi_unique_attack_conns) - # self.monitoring_client_test.IncludeKpi(kpi_unique_compromised_clients) - # self.monitoring_client_test.IncludeKpi(kpi_unique_attackers) - LOGGER.debug("KPIs sent to monitoring server") else: - LOGGER.debug("No KPIs to send to monitoring server") + LOGGER.debug("No KPIs sent to monitoring server") """ Classify connection as standard traffic or cryptomining attack and return results @@ -800,26 +519,11 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto """ def SendInput(self, request, context): - # Store the data sent in the request - # Protobuff messages are NOT pickable, so we need to serialize them first - # self.inference_values.put({"request": request, "timestamp": datetime.now()}) - # Perform inference with the data sent in the request logging.info("Performing inference...") cryptomining_detector_output = self.make_inference(request) logging.info("Inference performed correctly") - # # Store the results of the inference that will be later used to monitor the KPIs - # # Protobuff messages are NOT pickable, so we need to serialize them first - # cryptomining_detector_output_serialized = copy.deepcopy(cryptomining_detector_output) - # cryptomining_detector_output_serialized["service_id"] = MessageToJson( - # request.service_id, preserving_proto_field_name=True - # ) - # cryptomining_detector_output_serialized["endpoint_id"] = MessageToJson( - # request.endpoint_id, preserving_proto_field_name=True - # ) - - # self.inference_results.put({"output": cryptomining_detector_output_serialized, "timestamp": datetime.now()}) self.inference_results.append({"output": cryptomining_detector_output, "timestamp": datetime.now()}) service_id = request.service_id @@ -830,29 +534,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto if service_id not in self.service_ids: self.create_kpis(service_id, device_id, endpoint_id) self.service_ids.append(service_id) - # self.monitored_service_ids.put(MessageToJson(service_id, preserving_proto_field_name=True)) - - self.monitor_kpis_test() - # if self.kpi_test is None: - # _create_kpi_request = KpiDescriptor() - # _create_kpi_request.kpi_description = "KPI Description Test" - # _create_kpi_request.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN - # _create_kpi_request.device_id.device_uuid.uuid = "DEVUPM" # pylint: disable=maybe-no-member - # _create_kpi_request.service_id.service_uuid.uuid = "SERVUPM" # pylint: disable=maybe-no-member - # _create_kpi_request.endpoint_id.endpoint_uuid.uuid = "ENDUPM" # pylint: disable=maybe-no-member - # _create_kpi_request.connection_id.connection_uuid.uuid = "CONUPM" # pylint: disable=maybe-no-member - # _create_kpi_request.slice_id.slice_uuid.uuid = "SLIUPM" # pylint: disable=maybe-no-member - - # self.kpi_test = self.monitoring_client.SetKpi(_create_kpi_request) - # LOGGER.debug("KPI Test: {}".format(self.kpi_test)) - - # _include_kpi_request = Kpi() - # _include_kpi_request.kpi_id.kpi_id.uuid = self.kpi_test.kpi_id.uuid - # _include_kpi_request.timestamp.timestamp = timestamp_utcnow_to_float() - # _include_kpi_request.kpi_value.floatVal = 500 - - # self.monitoring_client.IncludeKpi(_include_kpi_request) + self.monitor_kpis() # Only notify Attack Mitigator when a cryptomining connection has been detected if cryptomining_detector_output["tag_name"] == "Crypto": @@ -880,13 +563,3 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto logging.info("No attack detected") return Empty(message="Ok, information received (no attack detected)") - - -""" - def GetOutput(self, request, context): - logging.info("Returning inference output...") - k = np.multiply(self.inference_values, [2]) - k = np.sum(k) - - return self.make_inference(k) -"""