diff --git a/proto/l3_centralizedattackdetector.proto b/proto/l3_centralizedattackdetector.proto index fc1eddbeb3995ff8fcb67b35bbd4cb42f2ca8c65..a34e57310421baf1e1e52aa83f570ff16f6124c6 100644 --- a/proto/l3_centralizedattackdetector.proto +++ b/proto/l3_centralizedattackdetector.proto @@ -24,10 +24,15 @@ service L3Centralizedattackdetector { rpc SendInputBatch (L3CentralizedattackdetectorModelInput) returns (Empty) {} } +message Feature { + float feature = 1; +} + message L3CentralizedattackdetectorMetrics { // Input sent by the DAD compoenent to the ML model integrated in the CAD component. // Machine learning model features + /* float c_pkts_all = 1; float c_ack_cnt = 2; float c_bytes_uniq = 3; @@ -37,19 +42,35 @@ message L3CentralizedattackdetectorMetrics { float s_ack_cnt = 7; float s_bytes_uniq = 8; float s_pkts_data = 9; - float s_bytes_all = 10; + float s_bytes_all = 10;*/ + + repeated Feature features = 1; + ConnectionMetadata connection_metadata = 2; + /* + string ip_o = 2; + string port_o = 3; + string ip_d = 4; + string port_d = 5; + string flow_id = 6; + context.ServiceId service_id = 7; + context.EndPointId endpoint_id = 8; + string protocol = 9; + float time_start = 10; + float time_end = 11; + */ +} - // Conection identifier - string ip_o = 11; - string port_o = 12; - string ip_d = 13; - string port_d = 14; - string flow_id = 15; - context.ServiceId service_id = 16; - context.EndPointId endpoint_id = 17; - string protocol = 18; - float time_start = 19; - float time_end = 20; +message ConnectionMetadata { + string ip_o = 1; + string port_o = 2; + string ip_d = 3; + string port_d = 4; + string flow_id = 5; + context.ServiceId service_id = 6; + context.EndPointId endpoint_id = 7; + string protocol = 8; + float time_start = 9; + float time_end = 10; } // Collection (batcb) of model inputs that will be sent to the model diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py index 67b50e25d500b0dda06d3030f1f154378f885b0b..cae86b1a2e8a267b57853bd769a5956ec4685a59 100644 --- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py @@ -19,6 +19,7 @@ from datetime import timedelta import os import numpy as np import onnxruntime as rt +import onnx as ox import logging import time @@ -55,12 +56,25 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto def __init__(self): LOGGER.info("Creating Centralized Attack Detector Service") + onnx_model = ox.load_model(MODEL_FILE) + meta = onnx_model.metadata_props.add() + meta.key = "key" + meta.value = "value" + LOGGER.debug(onnx_model.metadata_props[0]) + self.inference_values = [] self.inference_results = [] self.model = rt.InferenceSession(MODEL_FILE) + + '''self.model._model_meta = metadata_proto + meta = self.model.get_modelmeta() + LOGGER.debug(meta.description) + time.sleep(10)''' + self.input_name = self.model.get_inputs()[0].name self.label_name = self.model.get_outputs()[0].name self.prob_name = self.model.get_outputs()[1].name + self.monitoring_client = MonitoringClient() self.service_ids = [] self.monitored_kpis = { @@ -447,7 +461,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto """ def make_inference(self, request): - x_data = np.array( + '''x_data = np.array( [ [ request.c_pkts_all, @@ -462,6 +476,12 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto request.s_bytes_all, ] ] + )''' + + x_data = np.array( + [ + [feature.feature for feature in request.features] + ] ) # Print input data shape @@ -518,19 +538,19 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto output_message = { "confidence": None, "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"), - "ip_o": request.ip_o, - "ip_d": request.ip_d, + "ip_o": request.connection_metadata.ip_o, + "ip_d": request.connection_metadata.ip_d, "tag_name": None, "tag": None, - "flow_id": request.flow_id, - "protocol": request.protocol, - "port_o": request.port_o, - "port_d": request.port_d, + "flow_id": request.connection_metadata.flow_id, + "protocol": request.connection_metadata.protocol, + "port_o": request.connection_metadata.port_o, + "port_d": request.connection_metadata.port_d, "ml_id": "RandomForest", - "service_id": request.service_id, - "endpoint_id": request.endpoint_id, - "time_start": request.time_start, - "time_end": request.time_end, + "service_id": request.connection_metadata.service_id, + "endpoint_id": request.connection_metadata.endpoint_id, + "time_start": request.connection_metadata.time_start, + "time_end": request.connection_metadata.time_end, } if predictions[0][1] >= self.CLASSIFICATION_THRESHOLD: @@ -559,9 +579,9 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto self.inference_results.append({"output": cryptomining_detector_output, "timestamp": datetime.now()}) - service_id = request.service_id - device_id = request.endpoint_id.device_id - endpoint_id = request.endpoint_id + service_id = request.connection_metadata.service_id + device_id = request.connection_metadata.endpoint_id.device_id + endpoint_id = request.connection_metadata.endpoint_id # Check if a request of a new service has been received and, if so, create the monitored KPIs for that service if service_id not in self.service_ids: