Commit 5ae60326 authored by Lluis Gifre Renom's avatar Lluis Gifre Renom
Browse files

Merge branch 'develop' of https://labs.etsi.org/rep/tfs/controller into feat/service-location

parents e2391fa5 579ef5ea
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -36,6 +36,12 @@ spec:
        env:
        - name: LOG_LEVEL
          value: "DEBUG"
        - name: BATCH_SIZE
          value: "256"
        - name: CAD_CLASSIFICATION_THRESHOLD
          value: "0.5"
        - name: MONITORED_KPIS_TIME_INTERVAL_AGG
          value: "60"
        readinessProbe:
          exec:
            command: ["/bin/grpc_health_probe", "-addr=:10001"]
+10 −17
Original line number Diff line number Diff line
@@ -65,9 +65,6 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        service_id.context_id.context_uuid.uuid = context_uuid
        service_id.service_uuid.uuid = service_uuid

        # Get service form Context
        # context_client = ContextClient()

        try:
            _service: Service = self.context_client.GetService(service_id)
        except:
@@ -88,11 +85,9 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):

        # Set RuleSet for this ACL ConfigRule
        acl_rule_set = acl_config_rule.acl.rule_set
        # TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule
        # acl_rule_set.name = "DROP-HTTPS"

        acl_rule_set.name = "DROP-TCP"
        acl_rule_set.type = AclRuleTypeEnum.ACLRULETYPE_IPV4
        # acl_rule_set.description = "DROP undesired HTTPS traffic"
        acl_rule_set.description = "DROP undesired TCP traffic"

        # Add ACLEntry to the ACLRuleSet
@@ -108,26 +103,24 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        acl_entry.match.dst_address = "{}/32".format(dst_ip)
        acl_entry.match.src_port = int(src_port)
        acl_entry.match.dst_port = int(dst_port)
        # TODO: update the following parameters; for instance, add them as parameters of the method configure_acl_rule

        acl_entry.action.forward_action = AclForwardActionEnum.ACLFORWARDINGACTION_DROP
        acl_entry.action.log_action = AclLogActionEnum.ACLLOGACTION_NOLOG

        LOGGER.info("ACL Rule Set: %s", acl_rule_set)
        LOGGER.info("ACL Config Rule: %s", acl_config_rule)
        LOGGER.info("ACL Rule Set: %s", grpc_message_to_json_string(acl_rule_set))
        LOGGER.info("ACL Config Rule: %s", grpc_message_to_json_string(acl_config_rule))

        # Add the ACLRuleSet to the list of configured ACLRuleSets
        self.configured_acl_config_rules.append(acl_config_rule)

        # Update the Service with the new ACL RuleSet
        # service_client = ServiceClient()
        service_reply: ServiceId = self.service_client.UpdateService(service_request)

        # TODO: Log the service_reply details
        LOGGER.info("Service reply: %s", grpc_message_to_json_string(service_reply))

        if service_reply != service_request.service_id:  # pylint: disable=no-member
            raise Exception("Service update failed. Wrong ServiceId was returned")


    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def PerformMitigation(self, request, context):
        last_value = request.confidence
@@ -148,7 +141,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        counter = 0
        service_id = request.service_id

        LOGGER.info("Service Id.:\n{}".format(service_id))
        LOGGER.info("Service Id.:\n{}".format(grpc_message_to_json_string(service_id)))

        LOGGER.info("Retrieving service from Context")
        while sentinel:
@@ -160,7 +153,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
                LOGGER.debug("Waiting 2 seconds", counter, e)
                time.sleep(2)

        LOGGER.info(f"Service with Service Id.: {service_id}\n{service}")
        LOGGER.info(f"Service with Service Id.: {grpc_message_to_json_string(service_id)}\n{grpc_message_to_json_string(service)}")

        LOGGER.info("Adding new rule to the service to block the attack")
        self.configure_acl_rule(
@@ -173,20 +166,20 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
            src_port=port_o,
            dst_port=port_d,
        )
        LOGGER.info("Service with new rule:\n{}".format(service))
        LOGGER.info("Service with new rule:\n{}".format(grpc_message_to_json_string(service)))

        LOGGER.info("Updating service with the new rule")
        self.service_client.UpdateService(service)
        service = self.context_client.GetService(service_id)

        LOGGER.info(
            "Service obtained from Context after updating with the new rule:\n{}".format(
                self.context_client.GetService(service_id)
                grpc_message_to_json_string(service)
            )
        )

        return Empty(message=f"OK, received values: {last_tag} with confidence {last_value}.")


    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def GetConfiguredACLRules(self, request, context):
        acl_rules = ACLRules()
+95 −96
Original line number Diff line number Diff line
@@ -13,46 +13,36 @@
# limitations under the License.

from __future__ import print_function
from datetime import datetime
from datetime import timedelta
from datetime import datetime, timedelta

import csv
import os
import numpy as np
import onnxruntime as rt
import logging
import time
import uuid

from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method
from common.proto.context_pb2 import Timestamp, SliceId, ConnectionId
from common.proto.kpi_sample_types_pb2 import KpiSampleType
from common.proto.l3_attackmitigator_pb2 import L3AttackmitigatorOutput
from common.proto.l3_centralizedattackdetector_pb2 import Empty, AutoFeatures
from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorServicer

from common.proto.l3_attackmitigator_pb2 import L3AttackmitigatorOutput

from common.proto.monitoring_pb2 import KpiDescriptor
from common.proto.kpi_sample_types_pb2 import KpiSampleType

from monitoring.client.MonitoringClient import MonitoringClient
from common.proto.monitoring_pb2 import Kpi

from common.proto.monitoring_pb2 import Kpi, KpiDescriptor
from common.tools.timestamp.Converters import timestamp_utcnow_to_float
from common.proto.context_pb2 import Timestamp, SliceId, ConnectionId

from monitoring.client.MonitoringClient import MonitoringClient
from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigatorClient

import uuid

from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method


LOGGER = logging.getLogger(__name__)
current_dir = os.path.dirname(os.path.abspath(__file__))

# Demo constants
# Constants
DEMO_MODE = False
ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]

BATCH_SIZE= 10

METRICS_POOL = MetricsPool('l3_centralizedattackdetector', 'RPC')
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 10))
METRICS_POOL = MetricsPool("l3_centralizedattackdetector", "RPC")


class ConnectionInfo:
@@ -100,14 +90,14 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        self.cryptomining_detector_features_metadata.sort()
        LOGGER.info("Cryptomining Detector Features: " + str(self.cryptomining_detector_features_metadata))

        LOGGER.info("Batch size: " + str(BATCH_SIZE))
        LOGGER.info(f"Batch size: {BATCH_SIZE}")

        self.input_name = self.cryptomining_detector_model.get_inputs()[0].name
        self.label_name = self.cryptomining_detector_model.get_outputs()[0].name
        self.prob_name = self.cryptomining_detector_model.get_outputs()[1].name

        # Kpi values
        self.l3_security_status = 0  # unnecessary
        # KPI values
        self.l3_security_status = 0
        self.l3_ml_model_confidence = 0
        self.l3_inferences_in_interval_counter = 0

@@ -163,8 +153,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        self.attackmitigator_client = l3_attackmitigatorClient()

        # Environment variables
        self.CLASSIFICATION_THRESHOLD = os.getenv("CAD_CLASSIFICATION_THRESHOLD", 0.5)
        self.MONITORED_KPIS_TIME_INTERVAL_AGG = os.getenv("MONITORED_KPIS_TIME_INTERVAL_AGG", 60)
        self.CLASSIFICATION_THRESHOLD = float(os.getenv("CAD_CLASSIFICATION_THRESHOLD", 0.5))
        self.MONITORED_KPIS_TIME_INTERVAL_AGG = int(os.getenv("MONITORED_KPIS_TIME_INTERVAL_AGG", 60))

        # Constants
        self.NORMAL_CLASS = 0
@@ -190,6 +180,20 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        self.false_positives = 0
        self.false_negatives = 0

        self.replica_uuid = uuid.uuid4()

        self.first_batch_request_time = 0
        self.last_batch_request_time = 0

        LOGGER.info("This replica's identifier is: " + str(self.replica_uuid))

        self.response_times_csv_file_path = "response_times.csv"
        col_names = ["timestamp_first_req", "timestamp_last_req", "total_time", "batch_size"]

        with open(self.response_times_csv_file_path, "w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow(col_names)

    """
    Create a monitored KPI for a specific service and add it to the Monitoring Client
        -input: 
@@ -224,19 +228,11 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        -output: None
    """

    def create_kpis(self, service_id, device_id, endpoint_id):
    def create_kpis(self, service_id):
        LOGGER.info("Creating KPIs for service {}".format(service_id))

        # for now, all the KPIs are created for all the services from which requests are received
        # all the KPIs are created for all the services from which requests are received
        for kpi in self.monitored_kpis:
            # generate random slice_id
            slice_id = SliceId()
            slice_id.slice_uuid.uuid = str(uuid.uuid4())

            # generate random connection_id
            connection_id = ConnectionId()
            connection_id.connection_uuid.uuid = str(uuid.uuid4())

            created_kpi = self.create_kpi(
                service_id,
                kpi,
@@ -262,14 +258,6 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                
                self.monitor_compute_l3_kpi(service_id, monitor_inference_results)
                
                # Demo mode inference results are erased
                """if DEMO_MODE:
                    # Delete fist half of the inference results
                    LOGGER.debug("inference_results len: {}".format(len(self.inference_results)))
                    self.inference_results = self.inference_results[len(self.inference_results)//2:]
                    LOGGER.debug("inference_results len after erase: {}".format(len(self.inference_results)))"""
                # end = time.time()
                # LOGGER.debug("Time to process inference results with erase: {}".format(end - start))
                LOGGER.debug("KPIs sent to monitoring server")
        else:
            LOGGER.debug("No KPIs sent to monitoring server")
@@ -314,7 +302,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        LOGGER.debug("time_interval_start: {}".format(self.time_interval_start))
        LOGGER.debug("time_interval_end: {}".format(self.time_interval_end))

    def monitor_compute_l3_kpi(self, service_id, monitor_inference_results):
    def monitor_compute_l3_kpi(self,):
        # L3 security status
        kpi_security_status = Kpi()
        kpi_security_status.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_security_status"]["kpi_id"])
@@ -389,19 +377,14 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto

        # Get batch size
        batch_size = x_data.shape[0]

        # Print batch size
        LOGGER.debug("batch_size: {}".format(batch_size))
        LOGGER.debug("x_data.shape: {}".format(x_data.shape))

        inference_time_start = time.perf_counter()
        inference_time_start = time.time()

        # Perform inference
        predictions = self.cryptomining_detector_model.run(
            [self.prob_name], {self.input_name: x_data.astype(np.float32)}
        )[0]

        inference_time_end = time.perf_counter()
        inference_time_end = time.time()

        # Measure inference time
        inference_time = inference_time_end - inference_time_start
@@ -480,14 +463,14 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        # Print input data shape
        LOGGER.debug("x_data.shape: {}".format(x_data.shape))

        inference_time_start = time.perf_counter()
        inference_time_start = time.time()

        # Perform inference
        predictions = self.cryptomining_detector_model.run(
            [self.prob_name], {self.input_name: x_data.astype(np.float32)}
        )[0]

        inference_time_end = time.perf_counter()
        inference_time_end = time.time()

        # Measure inference time
        inference_time = inference_time_end - inference_time_start
@@ -519,7 +502,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        # Gather the predicted class, the probability of that class and other relevant information required to block the attack
        output_messages = []
        for i, request in enumerate(requests):
            output_messages.append({
            output_messages.append(
                {
                    "confidence": None,
                    "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
                    "ip_o": request.connection_metadata.ip_o,
@@ -535,7 +519,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                    "endpoint_id": request.connection_metadata.endpoint_id,
                    "time_start": request.connection_metadata.time_start,
                    "time_end": request.connection_metadata.time_end,
            })
                }
            )

            if predictions[i][1] >= self.CLASSIFICATION_THRESHOLD:
                output_messages[i]["confidence"] = predictions[i][1]
@@ -554,13 +539,17 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            + request: L3CentralizedattackdetectorMetrics object with connection features information
        -output: Empty object with a message about the execution of the function
    """

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def AnalyzeConnectionStatistics(self, request, context):
        # Perform inference with the data sent in the request
        if len(self.active_requests) == 0:
            self.first_batch_request_time = time.time()

        self.active_requests.append(request)

        if len(self.active_requests) == BATCH_SIZE:
            logging.info("Performing inference...")
        if len(self.active_requests) >= BATCH_SIZE:
            LOGGER.debug("Performing inference... {}".format(self.replica_uuid))

            inference_time_start = time.time()
            cryptomining_detector_output = self.perform_distributed_inference(self.active_requests)
@@ -574,12 +563,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto

            for i, req in enumerate(self.active_requests):
                service_id = req.connection_metadata.service_id
                device_id = req.connection_metadata.endpoint_id.device_id
                endpoint_id = req.connection_metadata.endpoint_id

                # Check if a request of a new service has been received and, if so, create the monitored KPIs for that service
                if service_id not in self.service_ids:
                    self.create_kpis(service_id, device_id, endpoint_id)
                    self.create_kpis(service_id)
                    self.service_ids.append(service_id)

                monitor_kpis_start = time.time()
@@ -637,9 +624,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                        self.false_positives += 1

                    self.total_predictions += 1

                    # if False:
                    notification_time_start = time.perf_counter()
                    notification_time_start = time.time()

                    LOGGER.debug("Crypto attack detected")

@@ -651,8 +636,11 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                    try:
                        logging.info("Sending the connection information to the Attack Mitigator component...")
                        message = L3AttackmitigatorOutput(**cryptomining_detector_output[i])
                        response = self.attackmitigator_client.PerformMitigation(message)
                        notification_time_end = time.perf_counter()
                        
                        am_response = self.attackmitigator_client.PerformMitigation(message)
                        LOGGER.debug("AM response: {}".format(am_response))
                        
                        notification_time_end = time.time()

                        self.am_notification_times.append(notification_time_end - notification_time_start)

@@ -682,11 +670,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                                f.write("Std notification time: {}\n".format(std_notification_time))
                                f.write("Median notification time: {}\n".format(median_notification_time))

                        # logging.info("Attack Mitigator notified and received response: ", response.message)  # FIX No message received
                        logging.info("Attack Mitigator notified")

                        #return Empty(message="OK, information received and mitigator notified abou the attack")
                    
                    except Exception as e:
                        logging.error("Error notifying the Attack Mitigator component about the attack: ", e)
                        logging.error("Couldn't find l3_attackmitigator")
@@ -704,9 +689,22 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto

                        self.total_predictions += 1

                    # return Empty(message="Ok, information received (no attack detected)")
            
            self.active_requests = []
            self.last_batch_request_time = time.time()

            col_values = [
                self.first_batch_request_time,
                self.last_batch_request_time,
                self.last_batch_request_time - self.first_batch_request_time,
                BATCH_SIZE,
            ]

            LOGGER.debug("col_values: {}".format(col_values))

            with open(self.response_times_csv_file_path, "a", newline="") as file:
                writer = csv.writer(file)
                writer.writerow(col_values)

            return Empty(message="Ok, metrics processed")

        return Empty(message="Ok, information received")
@@ -766,6 +764,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
    Send features allocated in the metadata of the onnx file to the DAD
        -output: ONNX metadata as a list of integers
    """

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def GetFeaturesIds(self, request: Empty, context):
        features = AutoFeatures()
+0 −791

File deleted.

Preview size limit exceeded, changes collapsed.

+9 −3
Original line number Diff line number Diff line
# Scripts to automatically run the "Attack Detection & Mitigation at the L3 Layer" workflow (Scenario 3).
"launch_l3_attack_detection_and_mitigation.sh" launches the TeraFlow OS components, which includes the CentralizedAttackDetector and AttackMitigator componentes necessary to perform this workflow.
"launch_l3_attack_detection_and_mitigation_complete.sh" also launches the DistributedAttackDetector, which monitors the network data plane and passively collects traffic packets and aggregates them in network flows, which are then provided to the CentralizedAttackDetector to detect attacks that may be occurring in the network.
# Demonstration of the L3 Cybersecurity Components for Attack Detection and Mitigation

__Authors__: Partners of Universidad Politécnica de Madrid and Telefónica I+D

## Executing

```bash
bash src/tests/scenario3/l3/run.sh
```
 No newline at end of file
Loading