diff --git a/manifests/l3_centralizedattackdetectorservice.yaml b/manifests/l3_centralizedattackdetectorservice.yaml index 95c6d8176ca86c98c1e26d88267c864247ae8b5b..8a3be69b672200120afb4bca3892dd0c08ec2d65 100644 --- a/manifests/l3_centralizedattackdetectorservice.yaml +++ b/manifests/l3_centralizedattackdetectorservice.yaml @@ -42,6 +42,8 @@ spec: value: "0.5" - name: MONITORED_KPIS_TIME_INTERVAL_AGG value: "60" + - name: TEST_ML_MODEL + value: "0" readinessProbe: exec: command: ["/bin/grpc_health_probe", "-addr=:10001"] diff --git a/proto/l3_centralizedattackdetector.proto b/proto/l3_centralizedattackdetector.proto index ed99435aa7db6584b381079cb1e3d589fb9998b5..56273cb628b5d4b8517c2640de6a88b0e57dab3d 100644 --- a/proto/l3_centralizedattackdetector.proto +++ b/proto/l3_centralizedattackdetector.proto @@ -25,6 +25,9 @@ service L3Centralizedattackdetector { // Get the list of features used by the ML model in the CAD component rpc GetFeaturesIds (Empty) returns (AutoFeatures) {} + + // Sets the list of attack IPs in order to be used to compute the prediction accuracy of the ML model in the CAD component in case of testing the ML model + rpc SetAttackIPs (AttackIPs) returns (Empty) {} } message Feature { @@ -66,3 +69,7 @@ message L3CentralizedattackdetectorBatchInput { message Empty { string message = 1; } + +message AttackIPs { + repeated string attack_ips = 1; +} \ No newline at end of file diff --git a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py index c1ee4b3bf9b6968eabe25ec6493c68db0849ebb8..ad02f62430d27ed55390bbddd7709a8df52c3387 100644 --- a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py +++ b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py @@ -33,7 +33,6 @@ from service.client.ServiceClient import ServiceClient from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method LOGGER = logging.getLogger(__name__) - METRICS_POOL = MetricsPool("l3_attackmitigator", "RPC") diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py index f7cd9d55c604995ff82e91e3ee001361298fb611..36d1d7b92cbd2bfd7a804667a0c12205e1160d1f 100644 --- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py @@ -38,9 +38,8 @@ from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigato LOGGER = logging.getLogger(__name__) current_dir = os.path.dirname(os.path.abspath(__file__)) -# Constants -DEMO_MODE = False -ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"] +# Environment variables +TEST_ML_MODEL = True if int(os.getenv("TEST_ML_MODEL", 0)) == 1 else False BATCH_SIZE = int(os.getenv("BATCH_SIZE", 10)) METRICS_POOL = MetricsPool("l3_centralizedattackdetector", "RPC") @@ -642,7 +641,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto LOGGER.debug("Monitoring KPIs performed in {} seconds".format(monitor_kpis_end - monitor_kpis_start)) LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output[i])) - if DEMO_MODE: + if TEST_ML_MODEL: self.analyze_prediction_accuracy(cryptomining_detector_output[i]["confidence"]) connection_info = ConnectionInfo( @@ -679,10 +678,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto # Only notify Attack Mitigator when a cryptomining connection has been detected if cryptomining_detector_output[i]["tag_name"] == "Crypto": - if DEMO_MODE: + if TEST_ML_MODEL: self.attack_connections.append(connection_info) - if connection_info.ip_o in ATTACK_IPS or connection_info.ip_d in ATTACK_IPS: + if connection_info.ip_o in self.attack_ips or connection_info.ip_d in self.attack_ips: self.correct_attack_conns += 1 self.correct_predictions += 1 else: @@ -747,7 +746,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto LOGGER.info("No attack detected") if cryptomining_detector_output[i]["tag_name"] != "Crypto": - if connection_info.ip_o not in ATTACK_IPS and connection_info.ip_d not in ATTACK_IPS: + if connection_info.ip_o not in self.attack_ips and connection_info.ip_d not in self.attack_ips: self.correct_predictions += 1 else: LOGGER.debug("False negative: {}".format(connection_info)) @@ -866,3 +865,21 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto features_ids.auto_features.append(feature) return features_ids + + @safe_and_metered_rpc_method(METRICS_POOL, LOGGER) + def SetAttackIPs(self, request, context): + """ + Sets the list of attack IPs in order to be used to compute the prediction accuracy of the Centralized Attack Detector in case of testing the ML model. + + Args: + request (AttackIPs): A list of attack IPs. + context (Empty): The context of the request. + + Returns: + None + """ + + self.attack_ips = request.attack_ips + LOGGER.debug(f"Succesfully set attack IPs: {self.attack_ips}") + + return Empty(message="Attack IPs set.")