Commit a87e5ae2 authored by Amit Karamchandani Batra's avatar Amit Karamchandani Batra
Browse files

Rewrote all docstrings to comply with PEP 257 and added docstrings for all...

Rewrote all docstrings to comply with PEP 257 and added docstrings for all functions that did not have it in CAD and AM service implementations.
parent 2aad0738
Loading
Loading
Loading
Loading
+61 −8
Original line number Diff line number Diff line
@@ -34,12 +34,22 @@ from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_m

LOGGER = logging.getLogger(__name__)

METRICS_POOL = MetricsPool('l3_attackmitigator', 'RPC')
METRICS_POOL = MetricsPool("l3_attackmitigator", "RPC")


class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
    def __init__(self):
        LOGGER.info("Creating Attack Mitigator Service")
        """
        Initializes the Attack Mitigator service.

        Args:
            None.

        Returns:
            None.
        """

        LOGGER.info("Creating Attack Mitigator service")

        self.last_value = -1
        self.last_tag = 0
@@ -60,6 +70,23 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        src_port: str,
        dst_port: str,
    ) -> None:
        """
        Configures an ACL rule to block undesired TCP traffic.

        Args:
            context_uuid (str): The UUID of the context.
            service_uuid (str): The UUID of the service.
            device_uuid (str): The UUID of the device.
            endpoint_uuid (str): The UUID of the endpoint.
            src_ip (str): The source IP address.
            dst_ip (str): The destination IP address.
            src_port (str): The source port.
            dst_port (str): The destination port.

        Returns:
            None.
        """

        # Create ServiceId
        service_id = ServiceId()
        service_id.context_id.context_uuid.uuid = context_uuid
@@ -123,10 +150,23 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def PerformMitigation(self, request, context):
        """
        Performs mitigation on an attack by configuring an ACL rule to block undesired TCP traffic.

        Args:
            request (L3AttackmitigatorOutput): The request message containing the attack mitigation information.
            context (Empty): The context of the request.

        Returns:
            Empty: An empty response indicating that the attack mitigation information was received and processed.
        """

        last_value = request.confidence
        last_tag = request.tag

        LOGGER.info(f"Attack Mitigator received attack mitigation information. Prediction confidence: {last_value}, Predicted class: {last_tag}")
        LOGGER.info(
            f"Attack Mitigator received attack mitigation information. Prediction confidence: {last_value}, Predicted class: {last_tag}"
        )

        ip_o = request.ip_o
        ip_d = request.ip_d
@@ -149,7 +189,9 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
                LOGGER.debug(f"Waiting 2 seconds for service to be available (attempt: {counter})")
                time.sleep(2)

        LOGGER.info(f"Service with Service Id.: {grpc_message_to_json_string(service_id)}\n{grpc_message_to_json_string(service)}")
        LOGGER.info(
            f"Service with Service Id.: {grpc_message_to_json_string(service_id)}\n{grpc_message_to_json_string(service)}"
        )
        LOGGER.info("Adding new rule to the service to block the attack")

        self.configure_acl_rule(
@@ -178,6 +220,17 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def GetConfiguredACLRules(self, request, context):
        """
        Returns the configured ACL rules.

        Args:
            request (Empty): The request message.
            context (Empty): The context of the RPC call.

        Returns:
            acl_rules (ACLRules): The configured ACL rules.
        """

        acl_rules = ACLRules()

        for acl_config_rule in self.configured_acl_config_rules:
+153 −61
Original line number Diff line number Diff line
@@ -65,12 +65,17 @@ class ConnectionInfo:


class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer):

    def __init__(self):
        """
    Initialize variables, prediction model and clients of components used by CAD
        Initializes the Centralized Attack Detector service.

        Args:
            None

        Returns:
            None
        """

    def __init__(self):
        LOGGER.info("Creating Centralized Attack Detector Service")

        self.inference_values = []
@@ -82,14 +87,14 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        )
        self.cryptomining_detector_model = rt.InferenceSession(self.cryptomining_detector_model_path)

        # Load cryptomining detector features metadata from ONNX file
        # Load cryptomining attack detector features metadata from ONNX file
        self.cryptomining_detector_features_metadata = list(
            self.cryptomining_detector_model.get_modelmeta().custom_metadata_map.values()
        )
        self.cryptomining_detector_features_metadata = [float(x) for x in self.cryptomining_detector_features_metadata]
        self.cryptomining_detector_features_metadata.sort()

        LOGGER.info(f"Cryptomining Detector Features: {self.cryptomining_detector_features_metadata}")
        LOGGER.info(f"Cryptomining Attack Detector Features: {self.cryptomining_detector_features_metadata}")
        LOGGER.info(f"Batch size: {BATCH_SIZE}")

        self.input_name = self.cryptomining_detector_model.get_inputs()[0].name
@@ -121,7 +126,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        self.monitored_kpis = {
            "l3_security_status": {
                "kpi_id": None,
                "description": "L3 - Confidence of the cryptomining detector in the security status in the last time interval of the service {service_id}",
                "description": "L3 - Confidence of the cryptomining attack detector in the security status in the last time interval of the service {service_id}",
                "kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_SECURITY_STATUS_CRYPTO,
                "service_ids": [],
            },
@@ -193,16 +198,6 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            writer = csv.writer(file)
            writer.writerow(col_names)

    """
    Create a monitored KPI for a specific service and add it to the Monitoring Client
        -input: 
            + service_id: service ID where the KPI will be monitored
            + kpi_name: name of the KPI
            + kpi_description: description of the KPI
            + kpi_sample_type: KPI sample type of the KPI (it must be defined in the kpi_sample_types.proto file)
        -output: KPI identifier representing the KPI
    """

    def create_kpi(
        self,
        service_id,
@@ -210,24 +205,40 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        kpi_description,
        kpi_sample_type,
    ):
        """
        Creates a new KPI for a specific service and add it to the Monitoring client

        Args:
            service_id (ServiceID): The ID of the service.
            kpi_name (str): The name of the KPI.
            kpi_description (str): The description of the KPI.
            kpi_sample_type (KpiSampleType): The sample type of the KPI.

        Returns:
            kpi (Kpi): The created KPI.
        """

        kpidescriptor = KpiDescriptor()
        kpidescriptor.kpi_description = kpi_description
        kpidescriptor.service_id.service_uuid.uuid = service_id.service_uuid.uuid
        kpidescriptor.kpi_sample_type = kpi_sample_type
        new_kpi = self.monitoring_client.SetKpi(kpidescriptor)
        kpi = self.monitoring_client.SetKpi(kpidescriptor)

        LOGGER.info("Created KPI {}".format(kpi_name))

        return new_kpi
        return kpi

    def create_kpis(self, service_id):
        """
    Create the monitored KPIs for a specific service, add them to the Monitoring Client and store their identifiers in the monitored_kpis dictionary
        -input:
            + service_id: service ID where the KPIs will be monitored
        -output: None
        Creates the monitored KPIs for a specific service, adds them to the Monitoring client and stores their identifiers in the monitored_kpis dictionary

        Args:
            service_id (uuid): The ID of the service.

        Returns:
            None
        """

    def create_kpis(self, service_id):
        LOGGER.info("Creating KPIs for service {}".format(service_id))

        # all the KPIs are created for all the services from which requests are received
@@ -244,6 +255,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        LOGGER.info("Created KPIs for service {}".format(service_id))

    def monitor_kpis(self):
        """
        Monitors KPIs for all the services from which requests are received

        Args:
            None

        Returns:
            None
        """

        monitor_inference_results = self.inference_results
        monitor_service_ids = self.service_ids

@@ -262,6 +283,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            LOGGER.debug("No KPIs sent to monitoring server")

    def assign_timestamp(self, monitor_inference_results):
        """
        Assigns a timestamp to the monitored inference results.

        Args:
            monitor_inference_results (list): A list of monitored inference results.

        Returns:
            None
        """

        time_interval = self.MONITORED_KPIS_TIME_INTERVAL_AGG

        # assign the timestamp of the first inference result to the time_interval_start
@@ -304,6 +335,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
    def monitor_compute_l3_kpi(
        self,
    ):
        """
        Computes the monitored KPIs for a specific service and sends them to the Monitoring server

        Args:
            None

        Returns:
            None
        """

        # L3 security status
        kpi_security_status = Kpi()
        kpi_security_status.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_security_status"]["kpi_id"])
@@ -358,19 +399,36 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            LOGGER.debug("Error sending KPIs to monitoring server: {}".format(e))

    def monitor_ml_model_confidence(self):
        """
        Get the monitored KPI for the confidence of the ML model

        Args:
            None

        Returns:
            confidence (float): The monitored KPI for the confidence of the ML model
        """

        confidence = None

        if self.l3_security_status == 0:
            return self.l3_ml_model_confidence_normal
            confidence = self.l3_ml_model_confidence_normal
        else:
            confidence = self.l3_ml_model_confidence_crypto

        return self.l3_ml_model_confidence_crypto
        return confidence

    def perform_inference(self, request):
        """
    Classify connection as standard traffic or cryptomining attack and return results
        -input: 
            + request: L3CentralizedattackdetectorMetrics object with connection features information
        -output: L3AttackmitigatorOutput object with information about the assigned class and prediction confidence
        Performs inference on the input data using the Cryptomining Attack Detector model to classify the connection as standard traffic or cryptomining attack.

        Args:
            request (L3CentralizedattackdetectorMetrics): A L3CentralizedattackdetectorMetrics object with connection features information.

        Returns:
            dict: A dictionary containing the predicted class, the probability of that class, and other relevant information required to block the attack.
        """

    def perform_inference(self, request):
        x_data = np.array([[feature.feature for feature in request.features]])

        # Print input data shape
@@ -444,14 +502,17 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto

        return output_message

    def perform_batch_inference(self, requests):
        """
    Classify connection as standard traffic or cryptomining attack and return results
        -input: 
            + request: L3CentralizedattackdetectorMetrics object with connection features information
        -output: L3AttackmitigatorOutput object with information about the assigned class and prediction confidence
        Performs batch inference on the input data using the Cryptomining Attack Detector model to classify the connection as standard traffic or cryptomining attack.

        Args:
            requests (list): A list of L3CentralizedattackdetectorMetrics objects with connection features information.

        Returns:
            list: A list of dictionaries containing the predicted class, the probability of that class, and other relevant information required to block the attack for each request.
        """

    def perform_distributed_inference(self, requests):
        batch_size = len(requests)

        # Create an empty array to hold the input data
@@ -534,15 +595,19 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto

        return output_messages

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def AnalyzeConnectionStatistics(self, request, context):
        """
    Receive features from Attack Mitigator, predict attack and communicate with Attack Mitigator
        -input: 
            + request: L3CentralizedattackdetectorMetrics object with connection features information
        -output: Empty object with a message about the execution of the function
        Analyzes the connection statistics sent in the request, performs batch inference on the input data using the Cryptomining Attack Detector model to classify the connection as standard traffic or cryptomining attack, and notifies the Attack Mitigator component in case of attack.

        Args:
            request (L3CentralizedattackdetectorMetrics): A L3CentralizedattackdetectorMetrics object with connection features information.
            context (Empty): The context of the request.

        Returns:
            Empty: An empty response indicating that the information was received and processed.
        """

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def AnalyzeConnectionStatistics(self, request, context):
        # Perform inference with the data sent in the request
        if len(self.active_requests) == 0:
            self.first_batch_request_time = time.time()
@@ -553,7 +618,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            LOGGER.debug("Performing inference... {}".format(self.pod_id))

            inference_time_start = time.time()
            cryptomining_detector_output = self.perform_distributed_inference(self.active_requests)
            cryptomining_detector_output = self.perform_batch_inference(self.active_requests)
            inference_time_end = time.time()

            LOGGER.debug("Inference performed in {} seconds".format(inference_time_end - inference_time_start))
@@ -711,6 +776,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        return Empty(message="Ok, information received")

    def analyze_prediction_accuracy(self, confidence):
        """
        Analyzes the prediction accuracy of the Centralized Attack Detector.

        Args:
            confidence (float): The confidence level of the Cryptomining Attack Detector model.

        Returns:
            None
        """

        LOGGER.info("Number of Attack Connections Correctly Classified: {}".format(self.correct_attack_conns))
        LOGGER.info("Number of Attack Connections: {}".format(len(self.attack_connections)))

@@ -727,7 +802,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            cryptomining_attack_detection_acc = 0

        LOGGER.info("Cryptomining Attack Detection Accuracy: {}".format(cryptomining_attack_detection_acc))
        LOGGER.info("Cryptomining Detector Confidence: {}".format(confidence))
        LOGGER.info("Cryptomining Attack Detector Confidence: {}".format(confidence))

        with open("prediction_accuracy.txt", "a") as f:
            LOGGER.debug("Exporting prediction accuracy and confidence")
@@ -739,12 +814,23 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            f.write("False Positives: {}\n".format(self.false_positives))
            f.write("True Negatives: {}\n".format(self.total_predictions - len(self.attack_connections)))
            f.write("False Negatives: {}\n".format(self.false_negatives))
            f.write("Cryptomining Detector Confidence: {}\n\n".format(confidence))
            f.write("Cryptomining Attack Detector Confidence: {}\n\n".format(confidence))
            f.write("Timestamp: {}\n".format(datetime.now().strftime("%d/%m/%Y %H:%M:%S")))
            f.close()

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def AnalyzeBatchConnectionStatistics(self, request, context):
        """
        Analyzes a batch of connection statistics sent in the request, performs batch inference on the input data using the Cryptomining Attack Detector model to classify the connection as standard traffic or cryptomining attack, and notifies the Attack Mitigator component in case of attack.

        Args:
            request (L3CentralizedattackdetectorBatchMetrics): A L3CentralizedattackdetectorBatchMetrics object with connection features information.
            context (Empty): The context of the request.

        Returns:
            Empty: An empty response indicating that the information was received and processed.
        """

        batch_time_start = time.time()

        for metric in request.metrics:
@@ -761,16 +847,22 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto

        return Empty(message="OK, information received.")

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def GetFeaturesIds(self, request, context):
        """
    Send features allocated in the metadata of the onnx file to the DAD
        -output: ONNX metadata as a list of integers
        Returns a list of feature IDs used by the Cryptomining Attack Detector model.

        Args:
            request (Empty): An empty request object.
            context (Empty): The context of the request.

        Returns:
            features_ids (AutoFeatures): A list of feature IDs used by the Cryptomining Attack Detector model.
        """

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def GetFeaturesIds(self, request: Empty, context):
        features = AutoFeatures()
        features_ids = AutoFeatures()

        for feature in self.cryptomining_detector_features_metadata:
            features.auto_features.append(feature)
            features_ids.auto_features.append(feature)

        return features
        return features_ids