Skip to content
Snippets Groups Projects
Commit 651b180e authored by karamchandan's avatar karamchandan
Browse files

Refactored code to enhance testing capabilities:

- Eliminated hardcoded attack connection IPs and introduced a more flexible approach.
- Incorporated a new RPC method and protobuf message in the CAD proto to facilitate setting the attack connection IPs by an external component to enable computation of the ML model's performance for testing purposes.
- Introduced a new environment variable in CAD to enable or disable testing of the ML model accuracy.
parent a87e5ae2
No related branches found
No related tags found
2 merge requests!142Release TeraFlowSDN 2.1,!135Fixed L3 Cybersecurity framework
......@@ -42,6 +42,8 @@ spec:
value: "0.5"
- name: MONITORED_KPIS_TIME_INTERVAL_AGG
value: "60"
- name: TEST_ML_MODEL
value: "0"
readinessProbe:
exec:
command: ["/bin/grpc_health_probe", "-addr=:10001"]
......
......@@ -25,6 +25,9 @@ service L3Centralizedattackdetector {
// Get the list of features used by the ML model in the CAD component
rpc GetFeaturesIds (Empty) returns (AutoFeatures) {}
// Sets the list of attack IPs in order to be used to compute the prediction accuracy of the ML model in the CAD component in case of testing the ML model
rpc SetAttackIPs (AttackIPs) returns (Empty) {}
}
message Feature {
......@@ -66,3 +69,7 @@ message L3CentralizedattackdetectorBatchInput {
message Empty {
string message = 1;
}
message AttackIPs {
repeated string attack_ips = 1;
}
\ No newline at end of file
......@@ -33,7 +33,6 @@ from service.client.ServiceClient import ServiceClient
from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method
LOGGER = logging.getLogger(__name__)
METRICS_POOL = MetricsPool("l3_attackmitigator", "RPC")
......
......@@ -38,9 +38,8 @@ from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigato
LOGGER = logging.getLogger(__name__)
current_dir = os.path.dirname(os.path.abspath(__file__))
# Constants
DEMO_MODE = False
ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]
# Environment variables
TEST_ML_MODEL = True if int(os.getenv("TEST_ML_MODEL", 0)) == 1 else False
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 10))
METRICS_POOL = MetricsPool("l3_centralizedattackdetector", "RPC")
......@@ -642,7 +641,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
LOGGER.debug("Monitoring KPIs performed in {} seconds".format(monitor_kpis_end - monitor_kpis_start))
LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output[i]))
if DEMO_MODE:
if TEST_ML_MODEL:
self.analyze_prediction_accuracy(cryptomining_detector_output[i]["confidence"])
connection_info = ConnectionInfo(
......@@ -679,10 +678,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
# Only notify Attack Mitigator when a cryptomining connection has been detected
if cryptomining_detector_output[i]["tag_name"] == "Crypto":
if DEMO_MODE:
if TEST_ML_MODEL:
self.attack_connections.append(connection_info)
if connection_info.ip_o in ATTACK_IPS or connection_info.ip_d in ATTACK_IPS:
if connection_info.ip_o in self.attack_ips or connection_info.ip_d in self.attack_ips:
self.correct_attack_conns += 1
self.correct_predictions += 1
else:
......@@ -747,7 +746,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
LOGGER.info("No attack detected")
if cryptomining_detector_output[i]["tag_name"] != "Crypto":
if connection_info.ip_o not in ATTACK_IPS and connection_info.ip_d not in ATTACK_IPS:
if connection_info.ip_o not in self.attack_ips and connection_info.ip_d not in self.attack_ips:
self.correct_predictions += 1
else:
LOGGER.debug("False negative: {}".format(connection_info))
......@@ -866,3 +865,21 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
features_ids.auto_features.append(feature)
return features_ids
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def SetAttackIPs(self, request, context):
"""
Sets the list of attack IPs in order to be used to compute the prediction accuracy of the Centralized Attack Detector in case of testing the ML model.
Args:
request (AttackIPs): A list of attack IPs.
context (Empty): The context of the request.
Returns:
None
"""
self.attack_ips = request.attack_ips
LOGGER.debug(f"Succesfully set attack IPs: {self.attack_ips}")
return Empty(message="Attack IPs set.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment