Skip to content
Snippets Groups Projects
Commit 651b180e authored by karamchandan's avatar karamchandan
Browse files

Refactored code to enhance testing capabilities:

- Eliminated hardcoded attack connection IPs and introduced a more flexible approach.
- Incorporated a new RPC method and protobuf message in the CAD proto to facilitate setting the attack connection IPs by an external component to enable computation of the ML model's performance for testing purposes.
- Introduced a new environment variable in CAD to enable or disable testing of the ML model accuracy.
parent a87e5ae2
No related branches found
No related tags found
2 merge requests!142Release TeraFlowSDN 2.1,!135Fixed L3 Cybersecurity framework
...@@ -42,6 +42,8 @@ spec: ...@@ -42,6 +42,8 @@ spec:
value: "0.5" value: "0.5"
- name: MONITORED_KPIS_TIME_INTERVAL_AGG - name: MONITORED_KPIS_TIME_INTERVAL_AGG
value: "60" value: "60"
- name: TEST_ML_MODEL
value: "0"
readinessProbe: readinessProbe:
exec: exec:
command: ["/bin/grpc_health_probe", "-addr=:10001"] command: ["/bin/grpc_health_probe", "-addr=:10001"]
......
...@@ -25,6 +25,9 @@ service L3Centralizedattackdetector { ...@@ -25,6 +25,9 @@ service L3Centralizedattackdetector {
// Get the list of features used by the ML model in the CAD component // Get the list of features used by the ML model in the CAD component
rpc GetFeaturesIds (Empty) returns (AutoFeatures) {} rpc GetFeaturesIds (Empty) returns (AutoFeatures) {}
// Sets the list of attack IPs in order to be used to compute the prediction accuracy of the ML model in the CAD component in case of testing the ML model
rpc SetAttackIPs (AttackIPs) returns (Empty) {}
} }
message Feature { message Feature {
...@@ -66,3 +69,7 @@ message L3CentralizedattackdetectorBatchInput { ...@@ -66,3 +69,7 @@ message L3CentralizedattackdetectorBatchInput {
message Empty { message Empty {
string message = 1; string message = 1;
} }
message AttackIPs {
repeated string attack_ips = 1;
}
\ No newline at end of file
...@@ -33,7 +33,6 @@ from service.client.ServiceClient import ServiceClient ...@@ -33,7 +33,6 @@ from service.client.ServiceClient import ServiceClient
from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
METRICS_POOL = MetricsPool("l3_attackmitigator", "RPC") METRICS_POOL = MetricsPool("l3_attackmitigator", "RPC")
......
...@@ -38,9 +38,8 @@ from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigato ...@@ -38,9 +38,8 @@ from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigato
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
# Constants # Environment variables
DEMO_MODE = False TEST_ML_MODEL = True if int(os.getenv("TEST_ML_MODEL", 0)) == 1 else False
ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 10)) BATCH_SIZE = int(os.getenv("BATCH_SIZE", 10))
METRICS_POOL = MetricsPool("l3_centralizedattackdetector", "RPC") METRICS_POOL = MetricsPool("l3_centralizedattackdetector", "RPC")
...@@ -642,7 +641,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto ...@@ -642,7 +641,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
LOGGER.debug("Monitoring KPIs performed in {} seconds".format(monitor_kpis_end - monitor_kpis_start)) LOGGER.debug("Monitoring KPIs performed in {} seconds".format(monitor_kpis_end - monitor_kpis_start))
LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output[i])) LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output[i]))
if DEMO_MODE: if TEST_ML_MODEL:
self.analyze_prediction_accuracy(cryptomining_detector_output[i]["confidence"]) self.analyze_prediction_accuracy(cryptomining_detector_output[i]["confidence"])
connection_info = ConnectionInfo( connection_info = ConnectionInfo(
...@@ -679,10 +678,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto ...@@ -679,10 +678,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
# Only notify Attack Mitigator when a cryptomining connection has been detected # Only notify Attack Mitigator when a cryptomining connection has been detected
if cryptomining_detector_output[i]["tag_name"] == "Crypto": if cryptomining_detector_output[i]["tag_name"] == "Crypto":
if DEMO_MODE: if TEST_ML_MODEL:
self.attack_connections.append(connection_info) self.attack_connections.append(connection_info)
if connection_info.ip_o in ATTACK_IPS or connection_info.ip_d in ATTACK_IPS: if connection_info.ip_o in self.attack_ips or connection_info.ip_d in self.attack_ips:
self.correct_attack_conns += 1 self.correct_attack_conns += 1
self.correct_predictions += 1 self.correct_predictions += 1
else: else:
...@@ -747,7 +746,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto ...@@ -747,7 +746,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
LOGGER.info("No attack detected") LOGGER.info("No attack detected")
if cryptomining_detector_output[i]["tag_name"] != "Crypto": if cryptomining_detector_output[i]["tag_name"] != "Crypto":
if connection_info.ip_o not in ATTACK_IPS and connection_info.ip_d not in ATTACK_IPS: if connection_info.ip_o not in self.attack_ips and connection_info.ip_d not in self.attack_ips:
self.correct_predictions += 1 self.correct_predictions += 1
else: else:
LOGGER.debug("False negative: {}".format(connection_info)) LOGGER.debug("False negative: {}".format(connection_info))
...@@ -866,3 +865,21 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto ...@@ -866,3 +865,21 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
features_ids.auto_features.append(feature) features_ids.auto_features.append(feature)
return features_ids return features_ids
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def SetAttackIPs(self, request, context):
"""
Sets the list of attack IPs in order to be used to compute the prediction accuracy of the Centralized Attack Detector in case of testing the ML model.
Args:
request (AttackIPs): A list of attack IPs.
context (Empty): The context of the request.
Returns:
None
"""
self.attack_ips = request.attack_ips
LOGGER.debug(f"Succesfully set attack IPs: {self.attack_ips}")
return Empty(message="Attack IPs set.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment