From 651b180edfba3d6a6c659d84d97fcd9af1243a0d Mon Sep 17 00:00:00 2001
From: karamchandan <amit.kbatra@alumnos.upm.es>
Date: Mon, 10 Jul 2023 03:25:55 +0200
Subject: [PATCH] Refactored code to enhance testing capabilities: - Eliminated
 hardcoded attack connection IPs and introduced a more flexible approach. -
 Incorporated a new RPC method and protobuf message in the CAD proto to
 facilitate setting the attack connection IPs by an external component to
 enable computation of the ML model's performance for testing purposes. -
 Introduced a new environment variable in CAD to enable or disable testing of
 the ML model accuracy.

---
 .../l3_centralizedattackdetectorservice.yaml  |  2 ++
 proto/l3_centralizedattackdetector.proto      |  7 +++++
 .../l3_attackmitigatorServiceServicerImpl.py  |  1 -
 ...alizedattackdetectorServiceServicerImpl.py | 31 ++++++++++++++-----
 4 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/manifests/l3_centralizedattackdetectorservice.yaml b/manifests/l3_centralizedattackdetectorservice.yaml
index 95c6d8176..8a3be69b6 100644
--- a/manifests/l3_centralizedattackdetectorservice.yaml
+++ b/manifests/l3_centralizedattackdetectorservice.yaml
@@ -42,6 +42,8 @@ spec:
           value: "0.5"
         - name: MONITORED_KPIS_TIME_INTERVAL_AGG
           value: "60"
+        - name: TEST_ML_MODEL
+          value: "0"
         readinessProbe:
           exec:
             command: ["/bin/grpc_health_probe", "-addr=:10001"]
diff --git a/proto/l3_centralizedattackdetector.proto b/proto/l3_centralizedattackdetector.proto
index ed99435aa..56273cb62 100644
--- a/proto/l3_centralizedattackdetector.proto
+++ b/proto/l3_centralizedattackdetector.proto
@@ -25,6 +25,9 @@ service L3Centralizedattackdetector {
 
   // Get the list of features used by the ML model in the CAD component
   rpc GetFeaturesIds (Empty) returns (AutoFeatures) {}
+
+  // Sets the list of attack IPs in order to be used to compute the prediction accuracy of the ML model in the CAD component in case of testing the ML model
+  rpc SetAttackIPs (AttackIPs) returns (Empty) {}
 }
 
 message Feature {
@@ -66,3 +69,7 @@ message L3CentralizedattackdetectorBatchInput {
 message Empty {
 	string message = 1;
 }
+
+message AttackIPs {
+	repeated string attack_ips = 1;
+}
\ No newline at end of file
diff --git a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py
index c1ee4b3bf..ad02f6243 100644
--- a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py
+++ b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py
@@ -33,7 +33,6 @@ from service.client.ServiceClient import ServiceClient
 from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method
 
 LOGGER = logging.getLogger(__name__)
-
 METRICS_POOL = MetricsPool("l3_attackmitigator", "RPC")
 
 
diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
index f7cd9d55c..36d1d7b92 100644
--- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
+++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
@@ -38,9 +38,8 @@ from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigato
 LOGGER = logging.getLogger(__name__)
 current_dir = os.path.dirname(os.path.abspath(__file__))
 
-# Constants
-DEMO_MODE = False
-ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]
+# Environment variables
+TEST_ML_MODEL = True if int(os.getenv("TEST_ML_MODEL", 0)) == 1 else False
 BATCH_SIZE = int(os.getenv("BATCH_SIZE", 10))
 METRICS_POOL = MetricsPool("l3_centralizedattackdetector", "RPC")
 
@@ -642,7 +641,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                 LOGGER.debug("Monitoring KPIs performed in {} seconds".format(monitor_kpis_end - monitor_kpis_start))
                 LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output[i]))
 
-                if DEMO_MODE:
+                if TEST_ML_MODEL:
                     self.analyze_prediction_accuracy(cryptomining_detector_output[i]["confidence"])
 
                 connection_info = ConnectionInfo(
@@ -679,10 +678,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
 
                 # Only notify Attack Mitigator when a cryptomining connection has been detected
                 if cryptomining_detector_output[i]["tag_name"] == "Crypto":
-                    if DEMO_MODE:
+                    if TEST_ML_MODEL:
                         self.attack_connections.append(connection_info)
 
-                    if connection_info.ip_o in ATTACK_IPS or connection_info.ip_d in ATTACK_IPS:
+                    if connection_info.ip_o in self.attack_ips or connection_info.ip_d in self.attack_ips:
                         self.correct_attack_conns += 1
                         self.correct_predictions += 1
                     else:
@@ -747,7 +746,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                     LOGGER.info("No attack detected")
 
                     if cryptomining_detector_output[i]["tag_name"] != "Crypto":
-                        if connection_info.ip_o not in ATTACK_IPS and connection_info.ip_d not in ATTACK_IPS:
+                        if connection_info.ip_o not in self.attack_ips and connection_info.ip_d not in self.attack_ips:
                             self.correct_predictions += 1
                         else:
                             LOGGER.debug("False negative: {}".format(connection_info))
@@ -866,3 +865,21 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
             features_ids.auto_features.append(feature)
 
         return features_ids
+
+    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
+    def SetAttackIPs(self, request, context):
+        """
+        Sets the list of attack IPs in order to be used to compute the prediction accuracy of the Centralized Attack Detector in case of testing the ML model.
+
+        Args:
+            request (AttackIPs): A list of attack IPs.
+            context (Empty): The context of the request.
+
+        Returns:
+            None
+        """
+
+        self.attack_ips = request.attack_ips
+        LOGGER.debug(f"Succesfully set attack IPs: {self.attack_ips}")
+
+        return Empty(message="Attack IPs set.")
-- 
GitLab