diff --git a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py index 8664704524ecef779235af3ca3dc765af7af4898..abb9d4f6b947c4a260ef59b155579ac40249c9b6 100644 --- a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py +++ b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py @@ -14,41 +14,164 @@ from __future__ import print_function import logging -from common.proto.l3_attackmitigator_pb2 import ( - EmptyMitigator -) -from common.proto.l3_attackmitigator_pb2_grpc import ( - L3AttackmitigatorServicer, +from common.proto.l3_centralizedattackdetector_pb2 import Empty +from common.proto.l3_attackmitigator_pb2_grpc import L3AttackmitigatorServicer +from common.proto.context_pb2 import ( + Service, + ServiceId, + ServiceConfig, + ServiceTypeEnum, + ServiceStatusEnum, + ServiceStatus, + Context, + ContextId, + Uuid, + Timestamp, + ConfigRule, + ConfigRule_Custom, + ConfigActionEnum, + Device, + DeviceId, + DeviceConfig, + DeviceOperationalStatusEnum, + DeviceDriverEnum, + EndPoint, + Link, + LinkId, + EndPoint, + EndPointId, + Topology, + TopologyId, ) +from common.proto.context_pb2_grpc import ContextServiceStub +from common.proto.service_pb2_grpc import ServiceServiceStub +from datetime import datetime +import grpc +import time +import json LOGGER = logging.getLogger(__name__) +CONTEXT_CHANNEL = "192.168.165.78:1010" +SERVICE_CHANNEL = "192.168.165.78:3030" -class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): +class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): def __init__(self): LOGGER.debug("Creating Servicer...") self.last_value = -1 self.last_tag = 0 - + + def GenerateRuleValue(self, ip_o, ip_d, port_o, port_d): + value = { + "ipv4:source-address": ip_o, + "ipv4:destination-address": ip_d, + "transport:source-port": port_o, + "transport:destination-port": port_d, + "forwarding-action": "DROP", + } + return value + + def GenerateContextId(self, context_id): + context_id_obj = ContextId() + uuid = Uuid() + uuid.uuid = context_id + context_id_obj.context_uuid.CopyFrom(uuid) + return context_id_obj + + def GenerateServiceId(self, service_id): + service_id_obj = ServiceId() + context_id = ContextId() + uuid = Uuid() + uuid.uuid = service_id + context_id.context_uuid.CopyFrom(uuid) + service_id_obj.context_id.CopyFrom(context_id) + service_id_obj.service_uuid.CopyFrom(uuid) + return service_id_obj + + def GetConfigRule(self, ip_o, ip_d, port_o, port_d): + config_rule = ConfigRule() + config_rule_custom = ConfigRule_Custom() + config_rule.action = ConfigActionEnum.CONFIGACTION_SET + config_rule_custom.resource_key = "test" + # config_rule_custom.resource_value = str(self.GenerateRuleValue(ip_o, ip_d, port_o, port_d)) + config_rule_custom.resource_value = json.dumps(self.GenerateRuleValue(ip_o, ip_d, port_o, port_d)) + config_rule.custom.CopyFrom(config_rule_custom) + return config_rule + def SendOutput(self, request, context): # SEND CONFIDENCE TO MITIGATION SERVER - logging.debug("") - print("Server received mitigation values...", request.confidence) + print("Server received mitigation values...", request.confidence, flush=True) + last_value = request.confidence last_tag = request.tag + + ip_o = request.ip_o + ip_d = request.ip_d + port_o = request.port_o + port_d = request.port_d + + # service_id = self.GenerateServiceId(request.service_id) + # service = GetService(service_id) + + # context_id = self.GenerateContextId("admin") + + sentinel = True + counter = 0 + + # service_id_list = self.ListServiceIds(context_id) + + # print(hello, flush = True) + # print(hello.service_ids[0].service_uuid.uuid, flush=True) + + # service_id = service_id_list.service_ids[0] + service_id = request.service_id + + print("Service id: ", service_id, flush=True) + + while sentinel: + try: + service = self.GetService(service_id) + sentinel = False + except Exception as e: + counter = counter + 1 + print("Waiting 2 seconds", counter, e, flush=True) + time.sleep(2) + + print("Service obtained from id: ", service, flush=True) + + config_rule = self.GetConfigRule(ip_o, ip_d, port_o, port_d) + + service_config = ServiceConfig() + service_config.config_rules.extend([config_rule]) + service.service_config.CopyFrom(service_config) + + print("Service with new rule: ", service, flush=True) + self.UpdateService(service) + + service2 = self.GetService(service_id) + print("Service obtained from id after updating: ", service2, flush=True) + # RETURN OK TO THE CALLER - return EmptyMitigator( - message=f"OK, received values: {last_tag} with confidence {last_value}." - ) + return Empty(message=f"OK, received values: {last_tag} with confidence {last_value}.") + + def GetService(self, service_id): + with grpc.insecure_channel(CONTEXT_CHANNEL) as channel: + stub = ContextServiceStub(channel) + return stub.GetService(service_id) + + def ListServiceIds(self, context_id): + with grpc.insecure_channel(CONTEXT_CHANNEL) as channel: + stub = ContextServiceStub(channel) + return stub.ListServiceIds(context_id) + + def UpdateService(self, service): + with grpc.insecure_channel(SERVICE_CHANNEL) as channel: + stub = ServiceServiceStub(channel) + stub.UpdateService(service) def GetMitigation(self, request, context): # GET OR PERFORM MITIGATION STRATEGY logging.debug("") print("Returing mitigation strategy...") k = self.last_value * 2 - return EmptyMitigator( - message=f"Mitigation with double confidence = {k}" - ) - - - + return Empty(message=f"Mitigation with double confidence = {k}") diff --git a/src/l3_centralizedattackdetector/Dockerfile b/src/l3_centralizedattackdetector/Dockerfile index 3db5c2b4d7e4020b727d0d3f9b106f9c4af2e6b6..0a980d8b10b45a6096a8afae03c3a2ceeb638b88 100644 --- a/src/l3_centralizedattackdetector/Dockerfile +++ b/src/l3_centralizedattackdetector/Dockerfile @@ -63,6 +63,7 @@ RUN python3 -m pip install -r requirements.txt # Add component files into working directory WORKDIR /var/teraflow COPY src/l3_centralizedattackdetector/. l3_centralizedattackdetector +COPY src/monitoring/. monitoring # Start the service ENTRYPOINT ["python", "-m", "l3_centralizedattackdetector.service"] diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py index ad05b0ee62e87ce9028dc043b693c1b4cae008b3..741c74251b31e18013c46b8e921632009d2bab62 100644 --- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py @@ -19,26 +19,27 @@ import grpc import numpy as np import onnxruntime as rt import logging -from common.proto.l3_centralizedattackdetector_pb2 import ( - Empty, -) -from common.proto.l3_centralizedattackdetector_pb2_grpc import ( - L3CentralizedattackdetectorServicer, -) - -from common.proto.l3_attackmitigator_pb2 import ( - L3AttackmitigatorOutput, -) -from common.proto.l3_attackmitigator_pb2_grpc import ( - L3AttackmitigatorStub, -) +from common.proto.l3_centralizedattackdetector_pb2 import Empty +from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorServicer + +from common.proto.l3_attackmitigator_pb2 import L3AttackmitigatorOutput +from common.proto.l3_attackmitigator_pb2_grpc import L3AttackmitigatorStub + +# KPIs and Monitoring +from common.proto.monitoring_pb2 import KpiDescriptor +from common.proto.kpi_sample_types_pb2 import KpiSampleType + +# from monitoring.client.MonitoringClient import MonitoringClient +from monitoring.client.MonitoringClient import MonitoringClient +from common.proto.monitoring_pb2 import Kpi +from common.proto.context_pb2 import Timestamp LOGGER = logging.getLogger(__name__) here = os.path.dirname(os.path.abspath(__file__)) MODEL_FILE = os.path.join(here, "ml_model/teraflow_rf.onnx") -class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer): +class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer): def __init__(self): LOGGER.debug("Creating Servicer...") self.inference_values = [] @@ -46,11 +47,39 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto self.input_name = self.model.get_inputs()[0].name self.label_name = self.model.get_outputs()[0].name self.prob_name = self.model.get_outputs()[1].name - + self.monitoring_client = MonitoringClient() + self.predicted_class_kpi_id = None + self.class_probability_kpi_id = None + + def create_predicted_class_kpi(self, client: MonitoringClient, service_id): + # create kpi + kpi_description: KpiDescriptor = KpiDescriptor() + kpi_description.kpi_description = "L3 security status of service {}".format(service_id) + # kpi_description.service_id.service_uuid.uuid = service_id + kpi_description.service_id.service_uuid.uuid = str(service_id) + kpi_description.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN + new_kpi = client.SetKpi(kpi_description) + + LOGGER.info("Created Predicted Class KPI {}...".format(new_kpi.kpi_id)) + + return new_kpi + + def create_class_prob_kpi(self, client: MonitoringClient, service_id): + # create kpi + kpi_description: KpiDescriptor = KpiDescriptor() + kpi_description.kpi_description = "L3 security status of service {}".format(service_id) + kpi_description.service_id.service_uuid.uuid = service_id + kpi_description.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN + new_kpi = client.SetKpi(kpi_description) + + LOGGER.info("Created Class Probability KPI {}...".format(new_kpi.kpi_id)) + + return new_kpi def make_inference(self, request): # ML MODEL - x_data = np.array([ + x_data = np.array( + [ [ request.n_packets_server_seconds, request.n_packets_client_seconds, @@ -61,10 +90,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto request.n_packets_server_n_packets_client, request.n_bits_server_n_bits_client, ] - ]) + ] + ) - predictions = self.model.run( - [self.prob_name], {self.input_name: x_data.astype(np.float32)})[0] + predictions = self.model.run([self.prob_name], {self.input_name: x_data.astype(np.float32)})[0] # Output format output_message = { "confidence": None, @@ -101,23 +130,46 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto # MAKE INFERENCE output = self.make_inference(request) + # Monitoring + service_id = request.service_id + + if self.predicted_class_kpi_id is None: + self.predicted_class_kpi_id = self.create_predicted_class_kpi(self.monitoring_client, service_id) + + if self.class_probability_kpi_id is None: + self.class_probability_kpi_id = self.create_class_prob_kpi(self.monitoring_client, service_id) + + # Packet -> DAD -> CAD -> ML -> (2 Instantaneous Value: higher class probability, predicted class) -> Monitoring + # In addition, two counters: + # Counter 1: Total number of crypto attack connections + # Counter 2: Rate of crypto attack connections with respect to the total number of connections + + kpi_class = Kpi() + kpi_class.kpi_id.kpi_id.uuid = self.predicted_class_kpi_id.uuid + kpi_class.kpi_value.int32Val = 1 if request.tag_name == "Crypto" else 0 + + kpi_prob = Kpi() + kpi_prob.kpi_id.kpi_id.uuid = self.class_probability_kpi_id.uuid + kpi_prob.kpi_value.floatVal = request.confidence + + kpi_class.timestamp = kpi_prob.timestamp = Timestamp() + + self.monitoring_client.IncludeKpi(kpi_class) + self.monitoring_client.IncludeKpi(kpi_prob) + # SEND INFO TO MITIGATION SERVER try: - with grpc.insecure_channel("localhost:10002") as channel: - stub = L3AttackmitigatorStub(channel) - print("Sending to mitigator...") - response = stub.SendOutput(output) - print("Sent output to mitigator and received: ", response.message) - - # RETURN "OK" TO THE CALLER - return Empty( - message="OK, information received and mitigator notified" - ) + with grpc.insecure_channel("localhost:10002") as channel: + stub = L3AttackmitigatorStub(channel) + print("Sending to mitigator...") + response = stub.SendOutput(output) + print("Sent output to mitigator and received: ", response.message) + + # RETURN "OK" TO THE CALLER + return Empty(message="OK, information received and mitigator notified") except: - print('Couldnt find l3_attackmitigator') - return Empty( - message="Mitigator Not found" - ) + print("Couldnt find l3_attackmitigator") + return Empty(message="Mitigator Not found") def GetOutput(self, request, context): logging.debug("") @@ -126,6 +178,3 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto k = np.sum(k) return self.make_inference(k) - - -