Commit 2aad0738 authored by Amit Karamchandani Batra's avatar Amit Karamchandani Batra
Browse files

Homogenized string formatting and logger usage in CAD and AM service implementation

parent 0b4e791e
Loading
Loading
Loading
Loading
+10 −14
Original line number Diff line number Diff line
@@ -107,8 +107,8 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        acl_entry.action.forward_action = AclForwardActionEnum.ACLFORWARDINGACTION_DROP
        acl_entry.action.log_action = AclLogActionEnum.ACLLOGACTION_NOLOG

        LOGGER.info("ACL Rule Set: %s", grpc_message_to_json_string(acl_rule_set))
        LOGGER.info("ACL Config Rule: %s", grpc_message_to_json_string(acl_config_rule))
        LOGGER.info(f"ACL Rule Set: {grpc_message_to_json_string(acl_rule_set)}")
        LOGGER.info(f"ACL Config Rule: {grpc_message_to_json_string(acl_config_rule)}")

        # Add the ACLRuleSet to the list of configured ACLRuleSets
        self.configured_acl_config_rules.append(acl_config_rule)
@@ -116,7 +116,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        # Update the Service with the new ACL RuleSet
        service_reply: ServiceId = self.service_client.UpdateService(service_request)
        
        LOGGER.info("Service reply: %s", grpc_message_to_json_string(service_reply))
        LOGGER.info(f"Service reply: {grpc_message_to_json_string(service_reply)}")

        if service_reply != service_request.service_id:  # pylint: disable=no-member
            raise Exception("Service update failed. Wrong ServiceId was returned")
@@ -126,11 +126,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        last_value = request.confidence
        last_tag = request.tag

        LOGGER.info(
            "Attack Mitigator received attack mitigation information. Prediction confidence: %s, Predicted class: %s",
            last_value,
            last_tag,
        )
        LOGGER.info(f"Attack Mitigator received attack mitigation information. Prediction confidence: {last_value}, Predicted class: {last_tag}")

        ip_o = request.ip_o
        ip_d = request.ip_d
@@ -141,21 +137,21 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        counter = 0
        service_id = request.service_id

        LOGGER.info("Service Id.:\n{}".format(grpc_message_to_json_string(service_id)))

        LOGGER.info(f"Service Id.: {grpc_message_to_json_string(service_id)}")
        LOGGER.info("Retrieving service from Context")
        
        while sentinel:
            try:
                service = self.context_client.GetService(service_id)
                sentinel = False
            except Exception as e:
                counter = counter + 1
                LOGGER.debug("Waiting 2 seconds", counter, e)
                LOGGER.debug(f"Waiting 2 seconds for service to be available (attempt: {counter})")
                time.sleep(2)

        LOGGER.info(f"Service with Service Id.: {grpc_message_to_json_string(service_id)}\n{grpc_message_to_json_string(service)}")

        LOGGER.info("Adding new rule to the service to block the attack")
        
        self.configure_acl_rule(
            context_uuid=service_id.context_id.context_uuid.uuid,
            service_uuid=service_id.service_uuid.uuid,
@@ -167,8 +163,8 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
            dst_port=port_d,
        )
        LOGGER.info("Service with new rule:\n{}".format(grpc_message_to_json_string(service)))

        LOGGER.info("Updating service with the new rule")
        
        self.service_client.UpdateService(service)
        service = self.context_client.GetService(service_id)

+23 −22
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ class ConnectionInfo:
        )

    def __str__(self):
        return "ip_o: " + self.ip_o + "\nport_o: " + self.port_o + "\nip_d: " + self.ip_d + "\nport_d: " + self.port_d
        return f"ip_o: {self.ip_o}\nport_o: {self.port_o}\nip_d: {self.ip_d}\nport_d: {self.port_d}"


class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer):
@@ -88,8 +88,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        )
        self.cryptomining_detector_features_metadata = [float(x) for x in self.cryptomining_detector_features_metadata]
        self.cryptomining_detector_features_metadata.sort()
        LOGGER.info("Cryptomining Detector Features: " + str(self.cryptomining_detector_features_metadata))

        LOGGER.info(f"Cryptomining Detector Features: {self.cryptomining_detector_features_metadata}")
        LOGGER.info(f"Batch size: {BATCH_SIZE}")

        self.input_name = self.cryptomining_detector_model.get_inputs()[0].name
@@ -180,13 +180,12 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        self.false_positives = 0
        self.false_negatives = 0

        self.replica_uuid = uuid.uuid4()
        self.pod_id = uuid.uuid4()
        LOGGER.info(f"Pod Id.: {self.pod_id}")

        self.first_batch_request_time = 0
        self.last_batch_request_time = 0

        LOGGER.info("This replica's identifier is: " + str(self.replica_uuid))

        self.response_times_csv_file_path = "response_times.csv"
        col_names = ["timestamp_first_req", "timestamp_last_req", "total_time", "batch_size"]

@@ -302,7 +301,9 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        LOGGER.debug("time_interval_start: {}".format(self.time_interval_start))
        LOGGER.debug("time_interval_end: {}".format(self.time_interval_end))

    def monitor_compute_l3_kpi(self,):
    def monitor_compute_l3_kpi(
        self,
    ):
        # L3 security status
        kpi_security_status = Kpi()
        kpi_security_status.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_security_status"]["kpi_id"])
@@ -549,14 +550,14 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        self.active_requests.append(request)

        if len(self.active_requests) >= BATCH_SIZE:
            LOGGER.debug("Performing inference... {}".format(self.replica_uuid))
            LOGGER.debug("Performing inference... {}".format(self.pod_id))

            inference_time_start = time.time()
            cryptomining_detector_output = self.perform_distributed_inference(self.active_requests)
            inference_time_end = time.time()

            LOGGER.debug("Inference performed in {} seconds".format(inference_time_end - inference_time_start))
            logging.info("Inference performed correctly")
            LOGGER.info("Inference performed correctly")

            self.inference_results.append({"output": cryptomining_detector_output, "timestamp": datetime.now()})
            LOGGER.debug("inference_results length: {}".format(len(self.inference_results)))
@@ -629,12 +630,12 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                    LOGGER.debug("Crypto attack detected")

                    # Notify the Attack Mitigator component about the attack
                    logging.info(
                    LOGGER.info(
                        "Notifying the Attack Mitigator component about the attack in order to block the connection..."
                    )

                    try:
                        logging.info("Sending the connection information to the Attack Mitigator component...")
                        LOGGER.info("Sending the connection information to the Attack Mitigator component...")
                        message = L3AttackmitigatorOutput(**cryptomining_detector_output[i])

                        am_response = self.attackmitigator_client.PerformMitigation(message)
@@ -670,15 +671,15 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                                f.write("Std notification time: {}\n".format(std_notification_time))
                                f.write("Median notification time: {}\n".format(median_notification_time))

                        logging.info("Attack Mitigator notified")
                        LOGGER.info("Attack Mitigator notified")

                    except Exception as e:
                        logging.error("Error notifying the Attack Mitigator component about the attack: ", e)
                        logging.error("Couldn't find l3_attackmitigator")
                        LOGGER.error("Error notifying the Attack Mitigator component about the attack: ", e)
                        LOGGER.error("Couldn't find l3_attackmitigator")

                        return Empty(message="Attack Mitigator not found")
                else:
                    logging.info("No attack detected")
                    LOGGER.info("No attack detected")

                    if cryptomining_detector_output[i]["tag_name"] != "Crypto":
                        if connection_info.ip_o not in ATTACK_IPS and connection_info.ip_d not in ATTACK_IPS:
@@ -751,12 +752,12 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        batch_time_end = time.time()

        with open("batch_time.txt", "a") as f:
            f.write(str(len(request.metrics)) + "\n")
            f.write(str(batch_time_end - batch_time_start) + "\n\n")
            f.write(f"{len(request.metrics)}\n")
            f.write(f"{batch_time_end - batch_time_start}\n\n")
            f.close()

        logging.debug("Metrics: " + str(len(request.metrics)))
        logging.debug("Batch time: " + str(batch_time_end - batch_time_start))
        LOGGER.debug(f"Batch time: {batch_time_end - batch_time_start}")
        LOGGER.debug("Batch time: {}".format(batch_time_end - batch_time_start))

        return Empty(message="OK, information received.")