Skip to content
Snippets Groups Projects
Commit a87e5ae2 authored by karamchandan's avatar karamchandan
Browse files

Rewrote all docstrings to comply with PEP 257 and added docstrings for all...

Rewrote all docstrings to comply with PEP 257 and added docstrings for all functions that did not have it in CAD and AM service implementations.
parent 2aad0738
No related branches found
No related tags found
2 merge requests!142Release TeraFlowSDN 2.1,!135Fixed L3 Cybersecurity framework
......@@ -34,12 +34,22 @@ from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_m
LOGGER = logging.getLogger(__name__)
METRICS_POOL = MetricsPool('l3_attackmitigator', 'RPC')
METRICS_POOL = MetricsPool("l3_attackmitigator", "RPC")
class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
def __init__(self):
LOGGER.info("Creating Attack Mitigator Service")
"""
Initializes the Attack Mitigator service.
Args:
None.
Returns:
None.
"""
LOGGER.info("Creating Attack Mitigator service")
self.last_value = -1
self.last_tag = 0
......@@ -60,6 +70,23 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
src_port: str,
dst_port: str,
) -> None:
"""
Configures an ACL rule to block undesired TCP traffic.
Args:
context_uuid (str): The UUID of the context.
service_uuid (str): The UUID of the service.
device_uuid (str): The UUID of the device.
endpoint_uuid (str): The UUID of the endpoint.
src_ip (str): The source IP address.
dst_ip (str): The destination IP address.
src_port (str): The source port.
dst_port (str): The destination port.
Returns:
None.
"""
# Create ServiceId
service_id = ServiceId()
service_id.context_id.context_uuid.uuid = context_uuid
......@@ -115,7 +142,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
# Update the Service with the new ACL RuleSet
service_reply: ServiceId = self.service_client.UpdateService(service_request)
LOGGER.info(f"Service reply: {grpc_message_to_json_string(service_reply)}")
if service_reply != service_request.service_id: # pylint: disable=no-member
......@@ -123,10 +150,23 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def PerformMitigation(self, request, context):
"""
Performs mitigation on an attack by configuring an ACL rule to block undesired TCP traffic.
Args:
request (L3AttackmitigatorOutput): The request message containing the attack mitigation information.
context (Empty): The context of the request.
Returns:
Empty: An empty response indicating that the attack mitigation information was received and processed.
"""
last_value = request.confidence
last_tag = request.tag
LOGGER.info(f"Attack Mitigator received attack mitigation information. Prediction confidence: {last_value}, Predicted class: {last_tag}")
LOGGER.info(
f"Attack Mitigator received attack mitigation information. Prediction confidence: {last_value}, Predicted class: {last_tag}"
)
ip_o = request.ip_o
ip_d = request.ip_d
......@@ -139,7 +179,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
LOGGER.info(f"Service Id.: {grpc_message_to_json_string(service_id)}")
LOGGER.info("Retrieving service from Context")
while sentinel:
try:
service = self.context_client.GetService(service_id)
......@@ -149,9 +189,11 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
LOGGER.debug(f"Waiting 2 seconds for service to be available (attempt: {counter})")
time.sleep(2)
LOGGER.info(f"Service with Service Id.: {grpc_message_to_json_string(service_id)}\n{grpc_message_to_json_string(service)}")
LOGGER.info(
f"Service with Service Id.: {grpc_message_to_json_string(service_id)}\n{grpc_message_to_json_string(service)}"
)
LOGGER.info("Adding new rule to the service to block the attack")
self.configure_acl_rule(
context_uuid=service_id.context_id.context_uuid.uuid,
service_uuid=service_id.service_uuid.uuid,
......@@ -164,7 +206,7 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
)
LOGGER.info("Service with new rule:\n{}".format(grpc_message_to_json_string(service)))
LOGGER.info("Updating service with the new rule")
self.service_client.UpdateService(service)
service = self.context_client.GetService(service_id)
......@@ -178,6 +220,17 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def GetConfiguredACLRules(self, request, context):
"""
Returns the configured ACL rules.
Args:
request (Empty): The request message.
context (Empty): The context of the RPC call.
Returns:
acl_rules (ACLRules): The configured ACL rules.
"""
acl_rules = ACLRules()
for acl_config_rule in self.configured_acl_config_rules:
......
......@@ -65,12 +65,17 @@ class ConnectionInfo:
class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer):
def __init__(self):
"""
Initializes the Centralized Attack Detector service.
"""
Initialize variables, prediction model and clients of components used by CAD
"""
Args:
None
Returns:
None
"""
def __init__(self):
LOGGER.info("Creating Centralized Attack Detector Service")
self.inference_values = []
......@@ -82,14 +87,14 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
)
self.cryptomining_detector_model = rt.InferenceSession(self.cryptomining_detector_model_path)
# Load cryptomining detector features metadata from ONNX file
# Load cryptomining attack detector features metadata from ONNX file
self.cryptomining_detector_features_metadata = list(
self.cryptomining_detector_model.get_modelmeta().custom_metadata_map.values()
)
self.cryptomining_detector_features_metadata = [float(x) for x in self.cryptomining_detector_features_metadata]
self.cryptomining_detector_features_metadata.sort()
LOGGER.info(f"Cryptomining Detector Features: {self.cryptomining_detector_features_metadata}")
LOGGER.info(f"Cryptomining Attack Detector Features: {self.cryptomining_detector_features_metadata}")
LOGGER.info(f"Batch size: {BATCH_SIZE}")
self.input_name = self.cryptomining_detector_model.get_inputs()[0].name
......@@ -121,7 +126,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
self.monitored_kpis = {
"l3_security_status": {
"kpi_id": None,
"description": "L3 - Confidence of the cryptomining detector in the security status in the last time interval of the service {service_id}",
"description": "L3 - Confidence of the cryptomining attack detector in the security status in the last time interval of the service {service_id}",
"kpi_sample_type": KpiSampleType.KPISAMPLETYPE_L3_SECURITY_STATUS_CRYPTO,
"service_ids": [],
},
......@@ -193,16 +198,6 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
writer = csv.writer(file)
writer.writerow(col_names)
"""
Create a monitored KPI for a specific service and add it to the Monitoring Client
-input:
+ service_id: service ID where the KPI will be monitored
+ kpi_name: name of the KPI
+ kpi_description: description of the KPI
+ kpi_sample_type: KPI sample type of the KPI (it must be defined in the kpi_sample_types.proto file)
-output: KPI identifier representing the KPI
"""
def create_kpi(
self,
service_id,
......@@ -210,24 +205,40 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
kpi_description,
kpi_sample_type,
):
"""
Creates a new KPI for a specific service and add it to the Monitoring client
Args:
service_id (ServiceID): The ID of the service.
kpi_name (str): The name of the KPI.
kpi_description (str): The description of the KPI.
kpi_sample_type (KpiSampleType): The sample type of the KPI.
Returns:
kpi (Kpi): The created KPI.
"""
kpidescriptor = KpiDescriptor()
kpidescriptor.kpi_description = kpi_description
kpidescriptor.service_id.service_uuid.uuid = service_id.service_uuid.uuid
kpidescriptor.kpi_sample_type = kpi_sample_type
new_kpi = self.monitoring_client.SetKpi(kpidescriptor)
kpi = self.monitoring_client.SetKpi(kpidescriptor)
LOGGER.info("Created KPI {}".format(kpi_name))
return new_kpi
"""
Create the monitored KPIs for a specific service, add them to the Monitoring Client and store their identifiers in the monitored_kpis dictionary
-input:
+ service_id: service ID where the KPIs will be monitored
-output: None
"""
return kpi
def create_kpis(self, service_id):
"""
Creates the monitored KPIs for a specific service, adds them to the Monitoring client and stores their identifiers in the monitored_kpis dictionary
Args:
service_id (uuid): The ID of the service.
Returns:
None
"""
LOGGER.info("Creating KPIs for service {}".format(service_id))
# all the KPIs are created for all the services from which requests are received
......@@ -244,6 +255,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
LOGGER.info("Created KPIs for service {}".format(service_id))
def monitor_kpis(self):
"""
Monitors KPIs for all the services from which requests are received
Args:
None
Returns:
None
"""
monitor_inference_results = self.inference_results
monitor_service_ids = self.service_ids
......@@ -262,6 +283,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
LOGGER.debug("No KPIs sent to monitoring server")
def assign_timestamp(self, monitor_inference_results):
"""
Assigns a timestamp to the monitored inference results.
Args:
monitor_inference_results (list): A list of monitored inference results.
Returns:
None
"""
time_interval = self.MONITORED_KPIS_TIME_INTERVAL_AGG
# assign the timestamp of the first inference result to the time_interval_start
......@@ -304,6 +335,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
def monitor_compute_l3_kpi(
self,
):
"""
Computes the monitored KPIs for a specific service and sends them to the Monitoring server
Args:
None
Returns:
None
"""
# L3 security status
kpi_security_status = Kpi()
kpi_security_status.kpi_id.kpi_id.CopyFrom(self.monitored_kpis["l3_security_status"]["kpi_id"])
......@@ -358,19 +399,36 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
LOGGER.debug("Error sending KPIs to monitoring server: {}".format(e))
def monitor_ml_model_confidence(self):
if self.l3_security_status == 0:
return self.l3_ml_model_confidence_normal
"""
Get the monitored KPI for the confidence of the ML model
return self.l3_ml_model_confidence_crypto
Args:
None
"""
Classify connection as standard traffic or cryptomining attack and return results
-input:
+ request: L3CentralizedattackdetectorMetrics object with connection features information
-output: L3AttackmitigatorOutput object with information about the assigned class and prediction confidence
"""
Returns:
confidence (float): The monitored KPI for the confidence of the ML model
"""
confidence = None
if self.l3_security_status == 0:
confidence = self.l3_ml_model_confidence_normal
else:
confidence = self.l3_ml_model_confidence_crypto
return confidence
def perform_inference(self, request):
"""
Performs inference on the input data using the Cryptomining Attack Detector model to classify the connection as standard traffic or cryptomining attack.
Args:
request (L3CentralizedattackdetectorMetrics): A L3CentralizedattackdetectorMetrics object with connection features information.
Returns:
dict: A dictionary containing the predicted class, the probability of that class, and other relevant information required to block the attack.
"""
x_data = np.array([[feature.feature for feature in request.features]])
# Print input data shape
......@@ -444,14 +502,17 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
return output_message
"""
Classify connection as standard traffic or cryptomining attack and return results
-input:
+ request: L3CentralizedattackdetectorMetrics object with connection features information
-output: L3AttackmitigatorOutput object with information about the assigned class and prediction confidence
"""
def perform_batch_inference(self, requests):
"""
Performs batch inference on the input data using the Cryptomining Attack Detector model to classify the connection as standard traffic or cryptomining attack.
Args:
requests (list): A list of L3CentralizedattackdetectorMetrics objects with connection features information.
Returns:
list: A list of dictionaries containing the predicted class, the probability of that class, and other relevant information required to block the attack for each request.
"""
def perform_distributed_inference(self, requests):
batch_size = len(requests)
# Create an empty array to hold the input data
......@@ -534,15 +595,19 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
return output_messages
"""
Receive features from Attack Mitigator, predict attack and communicate with Attack Mitigator
-input:
+ request: L3CentralizedattackdetectorMetrics object with connection features information
-output: Empty object with a message about the execution of the function
"""
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def AnalyzeConnectionStatistics(self, request, context):
"""
Analyzes the connection statistics sent in the request, performs batch inference on the input data using the Cryptomining Attack Detector model to classify the connection as standard traffic or cryptomining attack, and notifies the Attack Mitigator component in case of attack.
Args:
request (L3CentralizedattackdetectorMetrics): A L3CentralizedattackdetectorMetrics object with connection features information.
context (Empty): The context of the request.
Returns:
Empty: An empty response indicating that the information was received and processed.
"""
# Perform inference with the data sent in the request
if len(self.active_requests) == 0:
self.first_batch_request_time = time.time()
......@@ -553,7 +618,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
LOGGER.debug("Performing inference... {}".format(self.pod_id))
inference_time_start = time.time()
cryptomining_detector_output = self.perform_distributed_inference(self.active_requests)
cryptomining_detector_output = self.perform_batch_inference(self.active_requests)
inference_time_end = time.time()
LOGGER.debug("Inference performed in {} seconds".format(inference_time_end - inference_time_start))
......@@ -711,6 +776,16 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
return Empty(message="Ok, information received")
def analyze_prediction_accuracy(self, confidence):
"""
Analyzes the prediction accuracy of the Centralized Attack Detector.
Args:
confidence (float): The confidence level of the Cryptomining Attack Detector model.
Returns:
None
"""
LOGGER.info("Number of Attack Connections Correctly Classified: {}".format(self.correct_attack_conns))
LOGGER.info("Number of Attack Connections: {}".format(len(self.attack_connections)))
......@@ -727,7 +802,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
cryptomining_attack_detection_acc = 0
LOGGER.info("Cryptomining Attack Detection Accuracy: {}".format(cryptomining_attack_detection_acc))
LOGGER.info("Cryptomining Detector Confidence: {}".format(confidence))
LOGGER.info("Cryptomining Attack Detector Confidence: {}".format(confidence))
with open("prediction_accuracy.txt", "a") as f:
LOGGER.debug("Exporting prediction accuracy and confidence")
......@@ -739,12 +814,23 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
f.write("False Positives: {}\n".format(self.false_positives))
f.write("True Negatives: {}\n".format(self.total_predictions - len(self.attack_connections)))
f.write("False Negatives: {}\n".format(self.false_negatives))
f.write("Cryptomining Detector Confidence: {}\n\n".format(confidence))
f.write("Cryptomining Attack Detector Confidence: {}\n\n".format(confidence))
f.write("Timestamp: {}\n".format(datetime.now().strftime("%d/%m/%Y %H:%M:%S")))
f.close()
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def AnalyzeBatchConnectionStatistics(self, request, context):
"""
Analyzes a batch of connection statistics sent in the request, performs batch inference on the input data using the Cryptomining Attack Detector model to classify the connection as standard traffic or cryptomining attack, and notifies the Attack Mitigator component in case of attack.
Args:
request (L3CentralizedattackdetectorBatchMetrics): A L3CentralizedattackdetectorBatchMetrics object with connection features information.
context (Empty): The context of the request.
Returns:
Empty: An empty response indicating that the information was received and processed.
"""
batch_time_start = time.time()
for metric in request.metrics:
......@@ -761,16 +847,22 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
return Empty(message="OK, information received.")
"""
Send features allocated in the metadata of the onnx file to the DAD
-output: ONNX metadata as a list of integers
"""
@safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
def GetFeaturesIds(self, request: Empty, context):
features = AutoFeatures()
def GetFeaturesIds(self, request, context):
"""
Returns a list of feature IDs used by the Cryptomining Attack Detector model.
Args:
request (Empty): An empty request object.
context (Empty): The context of the request.
Returns:
features_ids (AutoFeatures): A list of feature IDs used by the Cryptomining Attack Detector model.
"""
features_ids = AutoFeatures()
for feature in self.cryptomining_detector_features_metadata:
features.auto_features.append(feature)
features_ids.auto_features.append(feature)
return features
return features_ids
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment