# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function from datetime import datetime from datetime import timedelta import os import numpy as np import onnxruntime as rt import logging import time from common.proto.l3_centralizedattackdetector_pb2 import Empty, AutoFeatures from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorServicer from common.proto.l3_attackmitigator_pb2 import L3AttackmitigatorOutput from common.proto.monitoring_pb2 import KpiDescriptor from common.proto.kpi_sample_types_pb2 import KpiSampleType from monitoring.client.MonitoringClient import MonitoringClient from common.proto.monitoring_pb2 import Kpi from common.tools.timestamp.Converters import timestamp_utcnow_to_float from common.proto.context_pb2 import Timestamp, SliceId, ConnectionId from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigatorClient import uuid from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method import csv LOGGER = logging.getLogger(__name__) current_dir = os.path.dirname(os.path.abspath(__file__)) # Constants DEMO_MODE = False ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"] BATCH_SIZE = int(os.getenv("BATCH_SIZE", 10)) METRICS_POOL = MetricsPool("l3_centralizedattackdetector", "RPC") class ConnectionInfo: def __init__(self, ip_o, port_o, ip_d, port_d): self.ip_o = ip_o self.port_o = port_o self.ip_d = ip_d self.port_d = port_d def __eq__(self, other): return ( self.ip_o == other.ip_o and self.port_o == other.port_o and self.ip_d == other.ip_d and self.port_d == other.port_d ) def __str__(self): return "ip_o: " + self.ip_o + "\nport_o: " + self.port_o + "\nip_d: " + self.ip_d + "\nport_d: " + self.port_d class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer): """ Initialize variables, prediction model and clients of components used by CAD """ def __init__(self): LOGGER.info("Creating Centralized Attack Detector Service") self.inference_values = [] self.inference_results = [] self.cryptomining_detector_path = os.path.join(current_dir, "ml_model/cryptomining_detector/") self.cryptomining_detector_file_name = os.listdir(self.cryptomining_detector_path)[0] self.cryptomining_detector_model_path = os.path.join( self.cryptomining_detector_path, self.cryptomining_detector_file_name ) self.cryptomining_detector_model = rt.InferenceSession(self.cryptomining_detector_model_path) # Load cryptomining detector features metadata from ONNX file self.cryptomining_detector_features_metadata = list( self.cryptomining_detector_model.get_modelmeta().custom_metadata_map.values() ) self.cryptomining_detector_features_metadata = [float(x) for x in self.cryptomining_detector_features_metadata] self.cryptomining_detector_features_metadata.sort() LOGGER.info("Cryptomining Detector Features: " + str(self.cryptomining_detector_features_metadata)) LOGGER.info(f"Batch size: {BATCH_SIZE}") self.input_name = self.cryptomining_detector_model.get_inputs()[0].name self.label_name = self.cryptomining_detector_model.get_outputs()[0].name self.prob_name = self.cryptomining_detector_model.get_outputs()[1].name # Kpi values self.l3_security_status = 0 # unnecessary self.l3_ml_model_confidence = 0 self.l3_inferences_in_interval_counter = 0 self.l3_ml_model_confidence_normal = 0 self.l3_inferences_in_interval_counter_normal = 0 self.l3_ml_model_confidence_crypto = 0 self.l3_inferences_in_interval_counter_crypto = 0 self.l3_attacks = [] self.l3_unique_attack_conns = 0 self.l3_unique_compromised_clients = 0 self.l3_unique_attackers = 0 self.l3_non_empty_time_interval = False self.active_requests = [] self.monitoring_client = MonitoringClient() self.service_ids = [] self.monitored_kpis = { "l3_security_status": { "kpi_id": None, "description": "L3 - Confidence of the cryptomining detector in the security status in the last time interval of the service {service_id}", "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_SECURITY_STATUS_CRYPTO, "service_ids": [], }, "l3_ml_model_confidence": { "kpi_id": None, "description": "L3 - Security status of the service in a time interval of the service {service_id} (“0” if no attack has been detected on the service and “1” if a cryptomining attack has been detected)", "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_ML_CONFIDENCE, "service_ids": [], }, "l3_unique_attack_conns": { "kpi_id": None, "description": "L3 - Number of attack connections detected in a time interval of the service {service_id} (attacks of the same connection [origin IP, origin port, destination IP and destination port] are only considered once)", "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_ATTACK_CONNS, "service_ids": [], }, "l3_unique_compromised_clients": { "kpi_id": None, "description": "L3 - Number of unique compromised clients of the service in a time interval of the service {service_id} (attacks from the same origin IP are only considered once)", "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_COMPROMISED_CLIENTS, "service_ids": [], }, "l3_unique_attackers": { "kpi_id": None, "description": "L3 - number of unique attackers of the service in a time interval of the service {service_id} (attacks from the same destination IP are only considered once)", "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_UNIQUE_ATTACKERS, "service_ids": [], }, } self.attackmitigator_client = l3_attackmitigatorClient() # Environment variables self.CLASSIFICATION_THRESHOLD = float(os.getenv("CAD_CLASSIFICATION_THRESHOLD", 0.5)) self.MONITORED_KPIS_TIME_INTERVAL_AGG = int(os.getenv("MONITORED_KPIS_TIME_INTERVAL_AGG", 60)) # Constants self.NORMAL_CLASS = 0 self.CRYPTO_CLASS = 1 self.kpi_test = None self.time_interval_start = None self.time_interval_end = None # CAD evaluation tests self.cad_inference_times = [] self.cad_num_inference_measurements = 100 # AM evaluation tests self.am_notification_times = [] # List of attack connections self.attack_connections = [] self.correct_attack_conns = 0 self.correct_predictions = 0 self.total_predictions = 0 self.false_positives = 0 self.false_negatives = 0 self.replica_uuid = uuid.uuid4() self.first_batch_request_time = 0 self.last_batch_request_time = 0 LOGGER.info("This replica's identifier is: " + str(self.replica_uuid)) self.response_times_csv_file_path = "response_times.csv" col_names = ["timestamp_first_req", "timestamp_last_req", "total_time", "batch_size"] with open(self.response_times_csv_file_path, "w", newline="") as file: writer = csv.writer(file) writer.writerow(col_names) """ Create a monitored KPI for a specific service and add it to the Monitoring Client -input: + service_id: service ID where the KPI will be monitored + kpi_name: name of the KPI + kpi_description: description of the KPI + kpi_sample_type: KPI sample type of the KPI (it must be defined in the kpi_sample_types.proto file) -output: KPI identifier representing the KPI """ def create_kpi( self, service_id, kpi_name, kpi_description, kpi_sample_type, ): kpidescriptor = KpiDescriptor() kpidescriptor.kpi_description = kpi_description kpidescriptor.service_id.service_uuid.uuid = service_id.service_uuid.uuid kpidescriptor.kpi_sample_type = kpi_sample_type new_kpi = self.monitoring_client.SetKpi(kpidescriptor) LOGGER.info("Created KPI {}".format(kpi_name)) return new_kpi """ Create the monitored KPIs for a specific service, add them to the Monitoring Client and store their identifiers in the monitored_kpis dictionary -input: + service_id: service ID where the KPIs will be monitored -output: None """ def create_kpis(self, service_id, device_id, endpoint_id): LOGGER.info("Creating KPIs for service {}".format(service_id)) # for now, all the KPIs are created for all the services from which requests are received for kpi in self.monitored_kpis: # generate random slice_id slice_id = SliceId() slice_id.slice_uuid.uuid = str(uuid.uuid4()) # generate random connection_id connection_id = ConnectionId() connection_id.connection_uuid.uuid = str(uuid.uuid4()) created_kpi = self.create_kpi( service_id, kpi, self.monitored_kpis[kpi]["description"].format(service_id=service_id.service_uuid.uuid), self.monitored_kpis[kpi]["kpi_sample_type"], ) self.monitored_kpis[kpi]["kpi_id"] = created_kpi.kpi_id self.monitored_kpis[kpi]["service_ids"].append(service_id.service_uuid.uuid) LOGGER.info("Created KPIs for service {}".format(service_id)) def monitor_kpis(self): monitor_inference_results = self.inference_results monitor_service_ids = self.service_ids self.assign_timestamp(monitor_inference_results) non_empty_time_interval = self.l3_non_empty_time_interval if non_empty_time_interval: for service_id in monitor_service_ids: LOGGER.debug("service_id: {}".format(service_id)) self.monitor_compute_l3_kpi(service_id, monitor_inference_results) # Demo mode inference results are erased """if DEMO_MODE: # Delete fist half of the inference results LOGGER.debug("inference_results len: {}".format(len(self.inference_results))) self.inference_results = self.inference_results[len(self.inference_results)//2:] LOGGER.debug("inference_results len after erase: {}".format(len(self.inference_results)))""" # end = time.time() # LOGGER.debug("Time to process inference results with erase: {}".format(end - start)) LOGGER.debug("KPIs sent to monitoring server") else: LOGGER.debug("No KPIs sent to monitoring server") def assign_timestamp(self, monitor_inference_results): time_interval = self.MONITORED_KPIS_TIME_INTERVAL_AGG # assign the timestamp of the first inference result to the time_interval_start if self.time_interval_start is None: self.time_interval_start = monitor_inference_results[0]["timestamp"] LOGGER.debug("self.time_interval_start: {}".format(self.time_interval_start)) # add time_interval to the current time to get the time interval end LOGGER.debug("time_interval: {}".format(time_interval)) LOGGER.debug(timedelta(seconds=time_interval)) self.time_interval_end = self.time_interval_start + timedelta(seconds=time_interval) current_time = datetime.utcnow() LOGGER.debug("current_time: {}".format(current_time)) if current_time >= self.time_interval_end: self.time_interval_start = self.time_interval_end self.time_interval_end = self.time_interval_start + timedelta(seconds=time_interval) self.l3_security_status = 0 # unnecessary self.l3_ml_model_confidence = 0 self.l3_inferences_in_interval_counter = 0 self.l3_ml_model_confidence_normal = 0 self.l3_inferences_in_interval_counter_normal = 0 self.l3_ml_model_confidence_crypto = 0 self.l3_inferences_in_interval_counter_crypto = 0 self.l3_attacks = [] self.l3_unique_attack_conns = 0 self.l3_unique_compromised_clients = 0 self.l3_unique_attackers = 0 self.l3_non_empty_time_interval = False LOGGER.debug("time_interval_start: {}".format(self.time_interval_start)) LOGGER.debug("time_interval_end: {}".format(self.time_interval_end)) def monitor_compute_l3_kpi(self, service_id, monitor_inference_results): # L3 security status kpi_security_status = Kpi() kpi_security_status.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_security_status"]["kpi_id"]) kpi_security_status.kpi_value.int32Val = self.l3_security_status # L3 ML model confidence kpi_conf = Kpi() kpi_conf.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_ml_model_confidence"]["kpi_id"]) kpi_conf.kpi_value.floatVal = self.monitor_ml_model_confidence() # L3 unique attack connections kpi_unique_attack_conns = Kpi() kpi_unique_attack_conns.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_unique_attack_conns"]["kpi_id"]) kpi_unique_attack_conns.kpi_value.int32Val = self.l3_unique_attack_conns # L3 unique compromised clients kpi_unique_compromised_clients = Kpi() kpi_unique_compromised_clients.kpi_id.kpi_id.CopyFrom( self.monitored_kpis["l3_unique_compromised_clients"]["kpi_id"] ) kpi_unique_compromised_clients.kpi_value.int32Val = self.l3_unique_compromised_clients # L3 unique attackers kpi_unique_attackers = Kpi() kpi_unique_attackers.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_unique_attackers"]["kpi_id"]) kpi_unique_attackers.kpi_value.int32Val = self.l3_unique_attackers timestamp = Timestamp() timestamp.timestamp = timestamp_utcnow_to_float() kpi_security_status.timestamp.CopyFrom(timestamp) kpi_conf.timestamp.CopyFrom(timestamp) kpi_unique_attack_conns.timestamp.CopyFrom(timestamp) kpi_unique_compromised_clients.timestamp.CopyFrom(timestamp) kpi_unique_attackers.timestamp.CopyFrom(timestamp) LOGGER.debug("Sending KPIs to monitoring server") LOGGER.debug("kpi_security_status: {}".format(kpi_security_status)) LOGGER.debug("kpi_conf: {}".format(kpi_conf)) LOGGER.debug("kpi_unique_attack_conns: {}".format(kpi_unique_attack_conns)) LOGGER.debug("kpi_unique_compromised_clients: {}".format(kpi_unique_compromised_clients)) LOGGER.debug("kpi_unique_attackers: {}".format(kpi_unique_attackers)) try: self.monitoring_client.IncludeKpi(kpi_security_status) self.monitoring_client.IncludeKpi(kpi_conf) self.monitoring_client.IncludeKpi(kpi_unique_attack_conns) self.monitoring_client.IncludeKpi(kpi_unique_compromised_clients) self.monitoring_client.IncludeKpi(kpi_unique_attackers) except Exception as e: LOGGER.debug("Error sending KPIs to monitoring server: {}".format(e)) def monitor_ml_model_confidence(self): if self.l3_security_status == 0: return self.l3_ml_model_confidence_normal return self.l3_ml_model_confidence_crypto """ Classify connection as standard traffic or cryptomining attack and return results -input: + request: L3CentralizedattackdetectorMetrics object with connection features information -output: L3AttackmitigatorOutput object with information about the assigned class and prediction confidence """ def perform_inference(self, request): x_data = np.array([[feature.feature for feature in request.features]]) # Print input data shape LOGGER.debug("x_data.shape: {}".format(x_data.shape)) # Get batch size batch_size = x_data.shape[0] # Print batch size LOGGER.debug("batch_size: {}".format(batch_size)) LOGGER.debug("x_data.shape: {}".format(x_data.shape)) inference_time_start = time.time() # Perform inference predictions = self.cryptomining_detector_model.run( [self.prob_name], {self.input_name: x_data.astype(np.float32)} )[0] inference_time_end = time.time() # Measure inference time inference_time = inference_time_end - inference_time_start self.cad_inference_times.append(inference_time) if len(self.cad_inference_times) > self.cad_num_inference_measurements: inference_times_np_array = np.array(self.cad_inference_times) np.save(f"inference_times_{batch_size}.npy", inference_times_np_array) avg_inference_time = np.mean(inference_times_np_array) max_inference_time = np.max(inference_times_np_array) min_inference_time = np.min(inference_times_np_array) std_inference_time = np.std(inference_times_np_array) median_inference_time = np.median(inference_times_np_array) LOGGER.debug("Average inference time: {}".format(avg_inference_time)) LOGGER.debug("Max inference time: {}".format(max_inference_time)) LOGGER.debug("Min inference time: {}".format(min_inference_time)) LOGGER.debug("Standard deviation inference time: {}".format(std_inference_time)) LOGGER.debug("Median inference time: {}".format(median_inference_time)) with open(f"inference_times_stats_{batch_size}.txt", "w") as f: f.write("Average inference time: {}\n".format(avg_inference_time)) f.write("Max inference time: {}\n".format(max_inference_time)) f.write("Min inference time: {}\n".format(min_inference_time)) f.write("Standard deviation inference time: {}\n".format(std_inference_time)) f.write("Median inference time: {}\n".format(median_inference_time)) # Gather the predicted class, the probability of that class and other relevant information required to block the attack output_message = { "confidence": None, "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"), "ip_o": request.connection_metadata.ip_o, "ip_d": request.connection_metadata.ip_d, "tag_name": None, "tag": None, "flow_id": request.connection_metadata.flow_id, "protocol": request.connection_metadata.protocol, "port_o": request.connection_metadata.port_o, "port_d": request.connection_metadata.port_d, "ml_id": self.cryptomining_detector_file_name, "service_id": request.connection_metadata.service_id, "endpoint_id": request.connection_metadata.endpoint_id, "time_start": request.connection_metadata.time_start, "time_end": request.connection_metadata.time_end, } if predictions[0][1] >= self.CLASSIFICATION_THRESHOLD: output_message["confidence"] = predictions[0][1] output_message["tag_name"] = "Crypto" output_message["tag"] = self.CRYPTO_CLASS else: output_message["confidence"] = predictions[0][0] output_message["tag_name"] = "Normal" output_message["tag"] = self.NORMAL_CLASS return output_message """ Classify connection as standard traffic or cryptomining attack and return results -input: + request: L3CentralizedattackdetectorMetrics object with connection features information -output: L3AttackmitigatorOutput object with information about the assigned class and prediction confidence """ def perform_distributed_inference(self, requests): batch_size = len(requests) # Create an empty array to hold the input data x_data = np.empty((batch_size, len(requests[0].features))) # Fill in the input data array with features from each request for i, request in enumerate(requests): x_data[i] = [feature.feature for feature in request.features] # Print input data shape LOGGER.debug("x_data.shape: {}".format(x_data.shape)) inference_time_start = time.time() # Perform inference predictions = self.cryptomining_detector_model.run( [self.prob_name], {self.input_name: x_data.astype(np.float32)} )[0] inference_time_end = time.time() # Measure inference time inference_time = inference_time_end - inference_time_start self.cad_inference_times.append(inference_time) if len(self.cad_inference_times) > self.cad_num_inference_measurements: inference_times_np_array = np.array(self.cad_inference_times) np.save(f"inference_times_{batch_size}.npy", inference_times_np_array) avg_inference_time = np.mean(inference_times_np_array) max_inference_time = np.max(inference_times_np_array) min_inference_time = np.min(inference_times_np_array) std_inference_time = np.std(inference_times_np_array) median_inference_time = np.median(inference_times_np_array) LOGGER.debug("Average inference time: {}".format(avg_inference_time)) LOGGER.debug("Max inference time: {}".format(max_inference_time)) LOGGER.debug("Min inference time: {}".format(min_inference_time)) LOGGER.debug("Standard deviation inference time: {}".format(std_inference_time)) LOGGER.debug("Median inference time: {}".format(median_inference_time)) with open(f"inference_times_stats_{batch_size}.txt", "w") as f: f.write("Average inference time: {}\n".format(avg_inference_time)) f.write("Max inference time: {}\n".format(max_inference_time)) f.write("Min inference time: {}\n".format(min_inference_time)) f.write("Standard deviation inference time: {}\n".format(std_inference_time)) f.write("Median inference time: {}\n".format(median_inference_time)) # Gather the predicted class, the probability of that class and other relevant information required to block the attack output_messages = [] for i, request in enumerate(requests): output_messages.append( { "confidence": None, "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"), "ip_o": request.connection_metadata.ip_o, "ip_d": request.connection_metadata.ip_d, "tag_name": None, "tag": None, "flow_id": request.connection_metadata.flow_id, "protocol": request.connection_metadata.protocol, "port_o": request.connection_metadata.port_o, "port_d": request.connection_metadata.port_d, "ml_id": self.cryptomining_detector_file_name, "service_id": request.connection_metadata.service_id, "endpoint_id": request.connection_metadata.endpoint_id, "time_start": request.connection_metadata.time_start, "time_end": request.connection_metadata.time_end, } ) if predictions[i][1] >= self.CLASSIFICATION_THRESHOLD: output_messages[i]["confidence"] = predictions[i][1] output_messages[i]["tag_name"] = "Crypto" output_messages[i]["tag"] = self.CRYPTO_CLASS else: output_messages[i]["confidence"] = predictions[i][0] output_messages[i]["tag_name"] = "Normal" output_messages[i]["tag"] = self.NORMAL_CLASS return output_messages """ Receive features from Attack Mitigator, predict attack and communicate with Attack Mitigator -input: + request: L3CentralizedattackdetectorMetrics object with connection features information -output: Empty object with a message about the execution of the function """ @safe_and_metered_rpc_method(METRICS_POOL, LOGGER) def AnalyzeConnectionStatistics(self, request, context): # Perform inference with the data sent in the request if len(self.active_requests) == 0: self.first_batch_request_time = time.time() self.active_requests.append(request) LOGGER.debug("active_requests length: {}".format(len(self.active_requests))) LOGGER.debug("BATCH_SIZE: {}".format(BATCH_SIZE)) LOGGER.debug(len(self.active_requests) == BATCH_SIZE) LOGGER.debug("type(len(self.active_requests)): {}".format(type(len(self.active_requests)))) LOGGER.debug("type(BATCH_SIZE): {}".format(type(BATCH_SIZE))) if len(self.active_requests) >= BATCH_SIZE: LOGGER.debug("Performing inference... {}".format(self.replica_uuid)) inference_time_start = time.time() cryptomining_detector_output = self.perform_distributed_inference(self.active_requests) inference_time_end = time.time() LOGGER.debug("Inference performed in {} seconds".format(inference_time_end - inference_time_start)) logging.info("Inference performed correctly") self.inference_results.append({"output": cryptomining_detector_output, "timestamp": datetime.now()}) LOGGER.debug("inference_results length: {}".format(len(self.inference_results))) for i, req in enumerate(self.active_requests): service_id = req.connection_metadata.service_id device_id = req.connection_metadata.endpoint_id.device_id endpoint_id = req.connection_metadata.endpoint_id # Check if a request of a new service has been received and, if so, create the monitored KPIs for that service if service_id not in self.service_ids: self.create_kpis(service_id, device_id, endpoint_id) self.service_ids.append(service_id) monitor_kpis_start = time.time() self.monitor_kpis() monitor_kpis_end = time.time() LOGGER.debug("Monitoring KPIs performed in {} seconds".format(monitor_kpis_end - monitor_kpis_start)) LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output[i])) if DEMO_MODE: self.analyze_prediction_accuracy(cryptomining_detector_output[i]["confidence"]) connection_info = ConnectionInfo( req.connection_metadata.ip_o, req.connection_metadata.port_o, req.connection_metadata.ip_d, req.connection_metadata.port_d, ) self.l3_non_empty_time_interval = True if cryptomining_detector_output[i]["tag_name"] == "Crypto": self.l3_security_status = 1 self.l3_inferences_in_interval_counter_crypto += 1 self.l3_ml_model_confidence_crypto = ( self.l3_ml_model_confidence_crypto * (self.l3_inferences_in_interval_counter_crypto - 1) + cryptomining_detector_output[i]["confidence"] ) / self.l3_inferences_in_interval_counter_crypto if connection_info not in self.l3_attacks: self.l3_attacks.append(connection_info) self.l3_unique_attack_conns += 1 self.l3_unique_compromised_clients = len(set([conn.ip_o for conn in self.l3_attacks])) self.l3_unique_attackers = len(set([conn.ip_d for conn in self.l3_attacks])) else: self.l3_inferences_in_interval_counter_normal += 1 self.l3_ml_model_confidence_normal = ( self.l3_ml_model_confidence_normal * (self.l3_inferences_in_interval_counter_normal - 1) + cryptomining_detector_output[i]["confidence"] ) / self.l3_inferences_in_interval_counter_normal # Only notify Attack Mitigator when a cryptomining connection has been detected if cryptomining_detector_output[i]["tag_name"] == "Crypto": if DEMO_MODE: self.attack_connections.append(connection_info) if connection_info.ip_o in ATTACK_IPS or connection_info.ip_d in ATTACK_IPS: self.correct_attack_conns += 1 self.correct_predictions += 1 else: LOGGER.debug("False positive: {}".format(connection_info)) self.false_positives += 1 self.total_predictions += 1 # if False: notification_time_start = time.time() LOGGER.debug("Crypto attack detected") # Notify the Attack Mitigator component about the attack logging.info( "Notifying the Attack Mitigator component about the attack in order to block the connection..." ) try: logging.info("Sending the connection information to the Attack Mitigator component...") message = L3AttackmitigatorOutput(**cryptomining_detector_output[i]) response = self.attackmitigator_client.PerformMitigation(message) notification_time_end = time.time() self.am_notification_times.append(notification_time_end - notification_time_start) LOGGER.debug(f"am_notification_times length: {len(self.am_notification_times)}") LOGGER.debug(f"last am_notification_time: {self.am_notification_times[-1]}") if len(self.am_notification_times) > 100: am_notification_times_np_array = np.array(self.am_notification_times) np.save("am_notification_times.npy", am_notification_times_np_array) avg_notification_time = np.mean(am_notification_times_np_array) max_notification_time = np.max(am_notification_times_np_array) min_notification_time = np.min(am_notification_times_np_array) std_notification_time = np.std(am_notification_times_np_array) median_notification_time = np.median(am_notification_times_np_array) LOGGER.debug("Average notification time: {}".format(avg_notification_time)) LOGGER.debug("Max notification time: {}".format(max_notification_time)) LOGGER.debug("Min notification time: {}".format(min_notification_time)) LOGGER.debug("Std notification time: {}".format(std_notification_time)) LOGGER.debug("Median notification time: {}".format(median_notification_time)) with open("am_notification_times_stats.txt", "w") as f: f.write("Average notification time: {}\n".format(avg_notification_time)) f.write("Max notification time: {}\n".format(max_notification_time)) f.write("Min notification time: {}\n".format(min_notification_time)) f.write("Std notification time: {}\n".format(std_notification_time)) f.write("Median notification time: {}\n".format(median_notification_time)) # logging.info("Attack Mitigator notified and received response: ", response.message) # FIX No message received logging.info("Attack Mitigator notified") # return Empty(message="OK, information received and mitigator notified abou the attack") except Exception as e: logging.error("Error notifying the Attack Mitigator component about the attack: ", e) logging.error("Couldn't find l3_attackmitigator") return Empty(message="Attack Mitigator not found") else: logging.info("No attack detected") if cryptomining_detector_output[i]["tag_name"] != "Crypto": if connection_info.ip_o not in ATTACK_IPS and connection_info.ip_d not in ATTACK_IPS: self.correct_predictions += 1 else: LOGGER.debug("False negative: {}".format(connection_info)) self.false_negatives += 1 self.total_predictions += 1 # return Empty(message="Ok, information received (no attack detected)") self.active_requests = [] self.last_batch_request_time = time.time() col_values = [ self.first_batch_request_time, self.last_batch_request_time, self.last_batch_request_time - self.first_batch_request_time, BATCH_SIZE, ] LOGGER.debug("col_values: {}".format(col_values)) with open(self.response_times_csv_file_path, "a", newline="") as file: writer = csv.writer(file) writer.writerow(col_values) return Empty(message="Ok, metrics processed") return Empty(message="Ok, information received") def analyze_prediction_accuracy(self, confidence): LOGGER.info("Number of Attack Connections Correctly Classified: {}".format(self.correct_attack_conns)) LOGGER.info("Number of Attack Connections: {}".format(len(self.attack_connections))) if self.total_predictions > 0: overall_detection_acc = self.correct_predictions / self.total_predictions else: overall_detection_acc = 0 LOGGER.info("Overall Detection Accuracy: {}\n".format(overall_detection_acc)) if len(self.attack_connections) > 0: cryptomining_attack_detection_acc = self.correct_attack_conns / len(self.attack_connections) else: cryptomining_attack_detection_acc = 0 LOGGER.info("Cryptomining Attack Detection Accuracy: {}".format(cryptomining_attack_detection_acc)) LOGGER.info("Cryptomining Detector Confidence: {}".format(confidence)) with open("prediction_accuracy.txt", "a") as f: LOGGER.debug("Exporting prediction accuracy and confidence") f.write("Overall Detection Accuracy: {}\n".format(overall_detection_acc)) f.write("Cryptomining Attack Detection Accuracy: {}\n".format(cryptomining_attack_detection_acc)) f.write("Total Predictions: {}\n".format(self.total_predictions)) f.write("Total Positives: {}\n".format(len(self.attack_connections))) f.write("False Positives: {}\n".format(self.false_positives)) f.write("True Negatives: {}\n".format(self.total_predictions - len(self.attack_connections))) f.write("False Negatives: {}\n".format(self.false_negatives)) f.write("Cryptomining Detector Confidence: {}\n\n".format(confidence)) f.write("Timestamp: {}\n".format(datetime.now().strftime("%d/%m/%Y %H:%M:%S"))) f.close() @safe_and_metered_rpc_method(METRICS_POOL, LOGGER) def AnalyzeBatchConnectionStatistics(self, request, context): batch_time_start = time.time() for metric in request.metrics: self.AnalyzeConnectionStatistics(metric, context) batch_time_end = time.time() with open("batch_time.txt", "a") as f: f.write(str(len(request.metrics)) + "\n") f.write(str(batch_time_end - batch_time_start) + "\n\n") f.close() logging.debug("Metrics: " + str(len(request.metrics))) logging.debug("Batch time: " + str(batch_time_end - batch_time_start)) return Empty(message="OK, information received.") """ Send features allocated in the metadata of the onnx file to the DAD -output: ONNX metadata as a list of integers """ @safe_and_metered_rpc_method(METRICS_POOL, LOGGER) def GetFeaturesIds(self, request: Empty, context): features = AutoFeatures() for feature in self.cryptomining_detector_features_metadata: features.auto_features.append(feature) return features