# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import copy
import logging
import os
import uuid
from datetime import datetime, timedelta
from multiprocessing import Process, Queue
from time import sleep

import numpy as np
import onnxruntime as rt
from google.protobuf.json_format import MessageToJson, Parse

from common.proto.context_pb2 import DeviceId, EndPointId, ServiceId, SliceId, Timestamp
from common.proto.kpi_sample_types_pb2 import KpiSampleType
from common.proto.l3_attackmitigator_pb2 import L3AttackmitigatorOutput
from common.proto.l3_centralizedattackdetector_pb2 import Empty
from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorServicer
from common.proto.monitoring_pb2 import Kpi, KpiDescriptor
from common.tools.timestamp.Converters import timestamp_utcnow_to_float
from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigatorClient
from monitoring.client.MonitoringClient import MonitoringClient

# from context.client.ContextClient import ContextClient

LOGGER = logging.getLogger(__name__)

current_dir = os.path.dirname(os.path.abspath(__file__))
MODEL_FILE = os.path.join(current_dir, "ml_model/crypto_5g_rf_spider_features.onnx")


class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer):
    def __init__(self):
        """Initialize variables, the prediction model and the clients of the components used by the CAD."""
        LOGGER.info("Creating Centralized Attack Detector Service")

        self.inference_values = Queue()
        self.inference_results = Queue()

        # Load the ONNX cryptomining detection model and cache its input/output tensor names
        self.model = rt.InferenceSession(MODEL_FILE)
        self.input_name = self.model.get_inputs()[0].name
        self.label_name = self.model.get_outputs()[0].name
        self.prob_name = self.model.get_outputs()[1].name

        self.monitoring_client = MonitoringClient()
        self.service_ids = []
        self.monitored_service_ids = Queue()
        self.monitored_kpis = {
            "l3_security_status": {
                "kpi_id": None,
                "description": "L3 - Security status of the service in a time interval of the service {service_id} (0 if no attack has been detected on the service and 1 if a cryptomining attack has been detected)",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_SECURITY_STATUS_CRYPTO,
                "service_ids": [],
            },
            "l3_ml_model_confidence": {
                "kpi_id": None,
                "description": "L3 - Confidence of the cryptomining detector in the security status in the last time interval of the service {service_id}",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_ML_CONFIDENCE,
                "service_ids": [],
            },
            "l3_unique_attack_conns": {
                "kpi_id": None,
                "description": "L3 - Number of attack connections detected in a time interval of the service {service_id} (attacks of the same connection [origin IP, origin port, destination IP and destination port] are only considered once)",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_ATTACK_CONNS,
                "service_ids": [],
            },
            "l3_unique_compromised_clients": {
                "kpi_id": None,
                "description": "L3 - Number of unique compromised clients of the service in a time interval of the service {service_id} (attacks from the same origin IP are only considered once)",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_COMPROMISED_CLIENTS,
                "service_ids": [],
            },
            "l3_unique_attackers": {
                "kpi_id": None,
                "description": "L3 - Number of unique attackers of the service in a time interval of the service {service_id} (attacks from the same destination IP are only considered once)",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_ATTACKERS,
                "service_ids": [],
            },
        }
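        # NOTE: "kpi_id" is filled in later by create_kpis() once each KPI has been registered in
        # the Monitoring component; "service_ids" tracks the services for which that KPI is
        # already being monitored.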
        self.attackmitigator_client = l3_attackmitigatorClient()

        # self.context_client = ContextClient()
        # self.context_id = "admin"

        # Environment variables (os.getenv returns strings, so cast them to the expected numeric types)
        self.CLASSIFICATION_THRESHOLD = float(os.getenv("CAD_CLASSIFICATION_THRESHOLD", 0.5))
        self.MONITORED_KPIS_TIME_INTERVAL_AGG = int(os.getenv("MONITORED_KPIS_TIME_INTERVAL_AGG", 30))

        # Constants
        self.NORMAL_CLASS = 0
        self.CRYPTO_CLASS = 1

        # Monitoring process (started lazily by create_kpis when the first service is registered)
        self.monitoring_process = Process(
            target=self.monitor_kpis,
            args=(
                self.monitored_service_ids,
                self.inference_results,
            ),
        )
        # self.monitoring_process.start()

    def create_kpi(
        self,
        service_id,
        device_id,
        endpoint_id,
        # slice_id,
        kpi_name,
        kpi_description,
        kpi_sample_type,
    ):
        """
        Create a monitored KPI for a specific service and register it in the Monitoring component.
        -input:
            + service_id: ID of the service for which the KPI will be monitored
            + device_id: ID of the device associated to the KPI
            + endpoint_id: ID of the endpoint associated to the KPI
            + kpi_name: name of the KPI
            + kpi_description: description of the KPI
            + kpi_sample_type: KPI sample type of the KPI (it must be defined in the kpi_sample_types.proto file)
        -output: KPI identifier representing the KPI
        """
        kpidescriptor = KpiDescriptor()
        kpidescriptor.kpi_description = kpi_description
        kpidescriptor.service_id.service_uuid.uuid = service_id.service_uuid.uuid
        kpidescriptor.device_id.device_uuid.uuid = device_id.device_uuid.uuid
        kpidescriptor.endpoint_id.endpoint_uuid.uuid = endpoint_id.endpoint_uuid.uuid
        # kpidescriptor.slice_id.slice_uuid.uuid = slice_id.slice_uuid.uuid
        kpidescriptor.kpi_sample_type = kpi_sample_type
        new_kpi = self.monitoring_client.SetKpi(kpidescriptor)

        LOGGER.info("Created KPI {}".format(kpi_name))

        return new_kpi

    def create_kpis(self, service_id, device_id, endpoint_id):
        """
        Create the monitored KPIs for a specific service, register them in the Monitoring component
        and store their identifiers in the monitored_kpis dictionary.
        -input:
            + service_id: ID of the service for which the KPIs will be monitored
            + device_id: ID of the device associated to the KPIs
            + endpoint_id: ID of the endpoint associated to the KPIs
        -output: None
        """
        LOGGER.info("Creating KPIs for service {}".format(service_id))

        # For now, all the KPIs are created for every service from which requests are received
        for kpi in self.monitored_kpis:
            # slice_ids_list = self.context_client.ListSliceIds(self.context_id)[0]

            # # generate random slice_id
            # slice_id = SliceId()
            # slice_id.slice_uuid.uuid = str(uuid.uuid4())

            # generate random device_id
            device_id = DeviceId()
            device_id.device_uuid.uuid = str(uuid.uuid4())

            created_kpi = self.create_kpi(
                service_id,
                device_id,
                endpoint_id,
                # slice_id,
                kpi,
                self.monitored_kpis[kpi]["description"].format(service_id=service_id.service_uuid.uuid),
                self.monitored_kpis[kpi]["kpi_sample_type"],
            )
            self.monitored_kpis[kpi]["kpi_id"] = created_kpi.kpi_id
            self.monitored_kpis[kpi]["service_ids"].append(service_id.service_uuid.uuid)

        # A multiprocessing.Process can only be started once; guard against a second call to
        # create_kpis (triggered by a second service) trying to start it again
        if not self.monitoring_process.is_alive():
            self.monitoring_process.start()
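    # The monitoring loop below aggregates inference results in fixed windows of
    # MONITORED_KPIS_TIME_INTERVAL_AGG seconds. For example (illustrative values, assuming the
    # default 30-second window), results timestamped at t0+5s and t0+12s both fall into the
    # window [t0, t0+30) and are summarized into a single set of KPI samples for that window.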
    def monitor_kpis(self, service_ids, inference_results):
        """Periodically aggregate the inference results and report the monitored KPIs per service."""
        self.monitoring_client_test = MonitoringClient()

        monitor_inference_results = []
        monitor_service_ids = []

        # sleep(10)

        time_interval_start = None

        while True:
            # Get all the new information from the service_ids and inference_results queues.
            # Note: the queued entries are serialized (JSON strings) because protobuf messages
            # are not picklable across process boundaries; deserialize them here.
            LOGGER.debug("Sleeping for %s seconds", self.MONITORED_KPIS_TIME_INTERVAL_AGG)
            sleep(self.MONITORED_KPIS_TIME_INTERVAL_AGG)

            for i in range(service_ids.qsize()):
                new_service_id = service_ids.get()
                service_id = Parse(new_service_id, ServiceId())
                monitor_service_ids.append(service_id)

            for i in range(inference_results.qsize()):
                new_inference_result = inference_results.get()
                new_inference_result["output"]["service_id"] = Parse(
                    new_inference_result["output"]["service_id"], ServiceId()
                )
                new_inference_result["output"]["endpoint_id"] = Parse(
                    new_inference_result["output"]["endpoint_id"], EndPointId()
                )
                monitor_inference_results.append(new_inference_result)

            LOGGER.debug("monitor_inference_results: {}".format(len(monitor_inference_results)))
            LOGGER.debug("monitor_service_ids: {}".format(len(monitor_service_ids)))

            while len(monitor_inference_results) == 0:
                LOGGER.debug("monitor_inference_results is empty, waiting for new inference results")

                for i in range(inference_results.qsize()):
                    new_inference_result = inference_results.get()
                    new_inference_result["output"]["service_id"] = Parse(
                        new_inference_result["output"]["service_id"], ServiceId()
                    )
                    new_inference_result["output"]["endpoint_id"] = Parse(
                        new_inference_result["output"]["endpoint_id"], EndPointId()
                    )
                    monitor_inference_results.append(new_inference_result)

                sleep(1)

            for service_id in monitor_service_ids:
                LOGGER.debug("service_id: {}".format(service_id))

                time_interval = self.MONITORED_KPIS_TIME_INTERVAL_AGG
                # time_interval_start = datetime.utcnow()

                # Assign the timestamp of the first inference result to time_interval_start;
                # afterwards, advance the window by time_interval on every iteration
                if time_interval_start is None:
                    time_interval_start = monitor_inference_results[0]["timestamp"]
                else:
                    time_interval_start = time_interval_start + timedelta(seconds=time_interval)

                # Add time_interval to the interval start to get the time interval end
                time_interval_end = time_interval_start + timedelta(seconds=time_interval)

                # Delete the inference results that are previous to the time interval start
                deleted_items = []

                for i in range(len(monitor_inference_results)):
                    if monitor_inference_results[i]["timestamp"] < time_interval_start:
                        deleted_items.append(i)

                LOGGER.debug("deleted_items: {}".format(deleted_items))

                # deleted_items is in ascending order, so offset each index by the number of
                # elements already removed
                for i in range(len(deleted_items)):
                    monitor_inference_results.pop(deleted_items[i] - i)

                if len(monitor_inference_results) == 0:
                    break

                LOGGER.debug("time_interval_start: {}".format(time_interval_start))
                LOGGER.debug("time_interval_end: {}".format(time_interval_end))
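                # Each KPI below is computed from the inference results whose timestamps fall
                # inside [time_interval_start, time_interval_end) and whose service matches the
                # service currently being processed.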
                # L3 security status
                kpi_security_status = Kpi()
                kpi_security_status.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_security_status"]["kpi_id"])

                # Get the output.tag of the ML model for the last aggregation time interval, as
                # indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable
                outputs_last_time_interval = []

                for i in range(len(monitor_inference_results)):
                    if (
                        monitor_inference_results[i]["timestamp"] >= time_interval_start
                        and monitor_inference_results[i]["timestamp"] < time_interval_end
                        and monitor_inference_results[i]["output"]["service_id"] == service_id
                        and service_id.service_uuid.uuid in self.monitored_kpis["l3_security_status"]["service_ids"]
                    ):
                        outputs_last_time_interval.append(monitor_inference_results[i]["output"]["tag"])

                # Convert the list to a NumPy array: comparing a plain Python list against a
                # scalar with == yields a single bool instead of an element-wise comparison
                kpi_security_status.kpi_value.int32Val = (
                    0 if np.all(np.array(outputs_last_time_interval) == self.NORMAL_CLASS) else 1
                )

                # L3 ML model confidence
                kpi_conf = Kpi()
                kpi_conf.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_ml_model_confidence"]["kpi_id"])

                # Get the output.confidence of the ML model for the last aggregation time interval,
                # as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable
                confidences_normal_last_time_interval = []
                confidences_crypto_last_time_interval = []

                for i in range(len(monitor_inference_results)):
                    LOGGER.debug("monitor_inference_results[i]: {}".format(monitor_inference_results[i]))

                    if (
                        monitor_inference_results[i]["timestamp"] >= time_interval_start
                        and monitor_inference_results[i]["timestamp"] < time_interval_end
                        and monitor_inference_results[i]["output"]["service_id"] == service_id
                        and service_id.service_uuid.uuid in self.monitored_kpis["l3_ml_model_confidence"]["service_ids"]
                    ):
                        if monitor_inference_results[i]["output"]["tag"] == self.NORMAL_CLASS:
                            confidences_normal_last_time_interval.append(
                                monitor_inference_results[i]["output"]["confidence"]
                            )
                        elif monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS:
                            confidences_crypto_last_time_interval.append(
                                monitor_inference_results[i]["output"]["confidence"]
                            )
                        else:
                            LOGGER.debug("Unknown tag: {}".format(monitor_inference_results[i]["output"]["tag"]))

                LOGGER.debug("confidences_normal_last_time_interval: {}".format(confidences_normal_last_time_interval))
                LOGGER.debug("confidences_crypto_last_time_interval: {}".format(confidences_crypto_last_time_interval))

                kpi_conf.kpi_value.floatVal = (
                    np.mean(confidences_crypto_last_time_interval)
                    if np.all(np.array(outputs_last_time_interval) == self.CRYPTO_CLASS)
                    else np.mean(confidences_normal_last_time_interval)
                )

                # L3 unique attack connections
                kpi_unique_attack_conns = Kpi()
                kpi_unique_attack_conns.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_unique_attack_conns"]["kpi_id"])

                # Get the number of unique attack connections (grouping by origin IP, origin port,
                # destination IP and destination port) in the last aggregation time interval, as
                # indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable
                num_unique_attack_conns_last_time_interval = 0
                unique_attack_conns_last_time_interval = []

                for i in range(len(monitor_inference_results)):
                    if (
                        monitor_inference_results[i]["timestamp"] >= time_interval_start
                        and monitor_inference_results[i]["timestamp"] < time_interval_end
                        and monitor_inference_results[i]["output"]["service_id"] == service_id
                        and service_id.service_uuid.uuid in self.monitored_kpis["l3_unique_attack_conns"]["service_ids"]
                    ):
                        if monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS:
                            current_attack_conn = {
                                "ip_o": monitor_inference_results[i]["output"]["ip_o"],
                                "port_o": monitor_inference_results[i]["output"]["port_o"],
                                "ip_d": monitor_inference_results[i]["output"]["ip_d"],
                                "port_d": monitor_inference_results[i]["output"]["port_d"],
                            }

                            # Only count the connection if this 4-tuple has not been seen yet in
                            # the current time interval
                            if current_attack_conn not in unique_attack_conns_last_time_interval:
                                num_unique_attack_conns_last_time_interval += 1
                                unique_attack_conns_last_time_interval.append(current_attack_conn)

                kpi_unique_attack_conns.kpi_value.int32Val = num_unique_attack_conns_last_time_interval
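                # The deduplication above relies on dict equality; e.g. (illustrative values) two
                # results with {"ip_o": "192.168.0.2", "port_o": "4444", "ip_d": "10.0.0.1",
                # "port_d": "8080"} compare equal and are counted as a single attack connection.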
                # L3 unique compromised clients
                kpi_unique_compromised_clients = Kpi()
                kpi_unique_compromised_clients.kpi_id.kpi_id.CopyFrom(
                    self.monitored_kpis["l3_unique_compromised_clients"]["kpi_id"]
                )

                # Get the number of unique compromised clients (grouping by origin IP) in the last
                # aggregation time interval, as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable
                num_unique_compromised_clients_last_time_interval = 0
                unique_compromised_clients_last_time_interval = []

                for i in range(len(monitor_inference_results)):
                    if (
                        monitor_inference_results[i]["timestamp"] >= time_interval_start
                        and monitor_inference_results[i]["timestamp"] < time_interval_end
                        and monitor_inference_results[i]["output"]["service_id"] == service_id
                        and service_id.service_uuid.uuid
                        in self.monitored_kpis["l3_unique_compromised_clients"]["service_ids"]
                    ):
                        if monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS:
                            if (
                                monitor_inference_results[i]["output"]["ip_o"]
                                not in unique_compromised_clients_last_time_interval
                            ):
                                unique_compromised_clients_last_time_interval.append(
                                    monitor_inference_results[i]["output"]["ip_o"]
                                )
                                num_unique_compromised_clients_last_time_interval += 1

                kpi_unique_compromised_clients.kpi_value.int32Val = num_unique_compromised_clients_last_time_interval

                # L3 unique attackers
                kpi_unique_attackers = Kpi()
                kpi_unique_attackers.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_unique_attackers"]["kpi_id"])

                # Get the number of unique attackers (grouping by destination IP) in the last
                # aggregation time interval, as indicated by the self.MONITORED_KPIS_TIME_INTERVAL_AGG variable
                num_unique_attackers_last_time_interval = 0
                unique_attackers_last_time_interval = []

                for i in range(len(monitor_inference_results)):
                    if (
                        monitor_inference_results[i]["timestamp"] >= time_interval_start
                        and monitor_inference_results[i]["timestamp"] < time_interval_end
                        and monitor_inference_results[i]["output"]["service_id"] == service_id
                        and service_id.service_uuid.uuid in self.monitored_kpis["l3_unique_attackers"]["service_ids"]
                    ):
                        if monitor_inference_results[i]["output"]["tag"] == self.CRYPTO_CLASS:
                            if (
                                monitor_inference_results[i]["output"]["ip_d"]
                                not in unique_attackers_last_time_interval
                            ):
                                unique_attackers_last_time_interval.append(
                                    monitor_inference_results[i]["output"]["ip_d"]
                                )
                                num_unique_attackers_last_time_interval += 1

                kpi_unique_attackers.kpi_value.int32Val = num_unique_attackers_last_time_interval

                # Assign the same timestamp to all the KPIs of this time interval
                timestamp = Timestamp()
                timestamp.timestamp = timestamp_utcnow_to_float()

                kpi_security_status.timestamp.CopyFrom(timestamp)
                kpi_conf.timestamp.CopyFrom(timestamp)
                kpi_unique_attack_conns.timestamp.CopyFrom(timestamp)
                kpi_unique_compromised_clients.timestamp.CopyFrom(timestamp)
                kpi_unique_attackers.timestamp.CopyFrom(timestamp)

                LOGGER.debug("Sending KPIs to monitoring server")

                LOGGER.debug("kpi_security_status: {}".format(kpi_security_status))
                LOGGER.debug("kpi_conf: {}".format(kpi_conf))
                LOGGER.debug("kpi_unique_attack_conns: {}".format(kpi_unique_attack_conns))
                LOGGER.debug("kpi_unique_compromised_clients: {}".format(kpi_unique_compromised_clients))
                LOGGER.debug("kpi_unique_attackers: {}".format(kpi_unique_attackers))
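                # Test KPI: register a throwaway KPI descriptor and push a dummy value through a
                # second Monitoring client to verify end-to-end connectivity with the Monitoring
                # component before sending the real KPIs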
                _create_kpi_request = KpiDescriptor()
                _create_kpi_request.kpi_description = "KPI Description Test"
                _create_kpi_request.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN
                _create_kpi_request.device_id.device_uuid.uuid = "DEVUPM"  # pylint: disable=maybe-no-member
                _create_kpi_request.service_id.service_uuid.uuid = "SERVUPM"  # pylint: disable=maybe-no-member
                _create_kpi_request.endpoint_id.endpoint_uuid.uuid = "ENDUPM"  # pylint: disable=maybe-no-member

                new_kpi = self.monitoring_client_test.SetKpi(_create_kpi_request)
                LOGGER.debug("New KPI: {}".format(new_kpi))

                _include_kpi_request = Kpi()
                _include_kpi_request.kpi_id.kpi_id.uuid = new_kpi.kpi_id.uuid
                _include_kpi_request.timestamp.timestamp = timestamp_utcnow_to_float()
                _include_kpi_request.kpi_value.floatVal = 500

                self.monitoring_client_test.IncludeKpi(_include_kpi_request)

                self.monitoring_client.IncludeKpi(kpi_security_status)
                self.monitoring_client.IncludeKpi(kpi_conf)
                self.monitoring_client.IncludeKpi(kpi_unique_attack_conns)
                self.monitoring_client.IncludeKpi(kpi_unique_compromised_clients)
                self.monitoring_client.IncludeKpi(kpi_unique_attackers)

                LOGGER.debug("KPIs sent to monitoring server")

    def make_inference(self, request):
        """
        Classify a connection as standard traffic or cryptomining attack and return the results.
        -input:
            + request: L3CentralizedattackdetectorMetrics object with the connection features
        -output: dictionary with information about the assigned class and the prediction confidence
        """
        x_data = np.array(
            [
                [
                    request.c_pkts_all,
                    request.c_ack_cnt,
                    request.c_bytes_uniq,
                    request.c_pkts_data,
                    request.c_bytes_all,
                    request.s_pkts_all,
                    request.s_ack_cnt,
                    request.s_bytes_uniq,
                    request.s_pkts_data,
                    request.s_bytes_all,
                ]
            ]
        )

        # Run the ONNX model; predictions holds the per-class probabilities of the single sample
        predictions = self.model.run([self.prob_name], {self.input_name: x_data.astype(np.float32)})[0]

        # Gather the predicted class, the probability of that class and other relevant information
        # required to block the attack
        output_message = {
            "confidence": None,
            "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
            "ip_o": request.ip_o,
            "ip_d": request.ip_d,
            "tag_name": None,
            "tag": None,
            "flow_id": request.flow_id,
            "protocol": request.protocol,
            "port_o": request.port_o,
            "port_d": request.port_d,
            "ml_id": "RandomForest",
            "service_id": request.service_id,
            "endpoint_id": request.endpoint_id,
            "time_start": request.time_start,
            "time_end": request.time_end,
        }

        if predictions[0][1] >= self.CLASSIFICATION_THRESHOLD:
            output_message["confidence"] = predictions[0][1]
            output_message["tag_name"] = "Crypto"
            output_message["tag"] = self.CRYPTO_CLASS
        else:
            output_message["confidence"] = predictions[0][0]
            output_message["tag_name"] = "Normal"
            output_message["tag"] = self.NORMAL_CLASS

        return output_message
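    # Example (illustrative values): with CAD_CLASSIFICATION_THRESHOLD=0.5 and per-class
    # probabilities [0.2, 0.8], the connection is tagged "Crypto" with confidence 0.8; with
    # [0.7, 0.3], it is tagged "Normal" with confidence 0.7. Note that the order of the features
    # in x_data must match the column order the model was trained with; reordering the fields
    # would silently produce wrong predictions.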
    def SendInput(self, request, context):
        """
        Receive the features of a connection, predict whether it is an attack and, if so, notify
        the Attack Mitigator component.
        -input:
            + request: L3CentralizedattackdetectorMetrics object with the connection features
            + context: gRPC context of the request
        -output: Empty object with a message about the execution of the function
        """
        # Store the data sent in the request
        # Protobuf messages are NOT picklable, so they need to be serialized first
        # self.inference_values.put({"request": request, "timestamp": datetime.now()})

        # Perform inference with the data sent in the request
        logging.info("Performing inference...")
        cryptomining_detector_output = self.make_inference(request)
        logging.info("Inference performed correctly")

        # Store the results of the inference that will later be used to monitor the KPIs.
        # Protobuf messages are NOT picklable, so they need to be serialized first
        cryptomining_detector_output_serialized = copy.deepcopy(cryptomining_detector_output)
        cryptomining_detector_output_serialized["service_id"] = MessageToJson(
            request.service_id, preserving_proto_field_name=True
        )
        cryptomining_detector_output_serialized["endpoint_id"] = MessageToJson(
            request.endpoint_id, preserving_proto_field_name=True
        )

        self.inference_results.put({"output": cryptomining_detector_output_serialized, "timestamp": datetime.now()})

        service_id = request.service_id
        device_id = request.endpoint_id.device_id
        endpoint_id = request.endpoint_id

        # Check if a request for a new service has been received and, if so, create the monitored
        # KPIs for that service
        if service_id not in self.service_ids:
            self.create_kpis(service_id, device_id, endpoint_id)
            self.service_ids.append(service_id)
            self.monitored_service_ids.put(MessageToJson(service_id, preserving_proto_field_name=True))

        # Only notify the Attack Mitigator when a cryptomining connection has been detected
        if cryptomining_detector_output["tag_name"] == "Crypto":
            logging.info("Crypto attack detected")

            # Notify the Attack Mitigator component about the attack
            logging.info(
                "Notifying the Attack Mitigator component about the attack in order to block the connection..."
            )

            try:
                logging.info("Sending the connection information to the Attack Mitigator component...")
                message = L3AttackmitigatorOutput(**cryptomining_detector_output)
                response = self.attackmitigator_client.SendOutput(message)
                # logging.info("Attack Mitigator notified and received response: %s", response.message)  # FIX No message received
                logging.info("Attack Mitigator notified")

                return Empty(message="OK, information received and mitigator notified about the attack")
            except Exception as e:
                logging.error("Error notifying the Attack Mitigator component about the attack: %s", e)
                logging.error("Couldn't find l3_attackmitigator")

                return Empty(message="Attack Mitigator not found")
        else:
            logging.info("No attack detected")

            return Empty(message="Ok, information received (no attack detected)")

    """
    def GetOutput(self, request, context):
        logging.info("Returning inference output...")
        k = np.multiply(self.inference_values, [2])
        k = np.sum(k)

        return self.make_inference(k)
    """
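# Usage sketch (illustrative, not part of this module; assumes the generated gRPC stub class
# L3CentralizedattackdetectorStub from the l3_centralizedattackdetector_pb2_grpc module and a
# hypothetical service address "localhost:10001"):
#
#   channel = grpc.insecure_channel("localhost:10001")
#   stub = L3CentralizedattackdetectorStub(channel)
#   response = stub.SendInput(metrics)  # metrics: L3CentralizedattackdetectorMetrics
#   print(response.message)            # e.g. "Ok, information received (no attack detected)"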