Commit e88cecd0 authored by Luis de la Cal's avatar Luis de la Cal
Browse files

Added buffer to centralized attack detector component for it to wait until it...

Added buffer to centralized attack detector component for it to wait until it receives BATCH_SIZE requests until it makes an inference
parent 9758c5b2
Loading
Loading
Loading
Loading
+0 −2
Original line number Original line Diff line number Diff line
@@ -68,8 +68,6 @@ spec:
  - name: grpc
  - name: grpc
    port: 10002
    port: 10002
    targetPort: 10002
    targetPort: 10002
strategy:
  type: Recreate


---
---
apiVersion: autoscaling/v2
apiVersion: autoscaling/v2
+0 −2
Original line number Original line Diff line number Diff line
@@ -68,8 +68,6 @@ spec:
  - name: grpc
  - name: grpc
    port: 10001
    port: 10001
    targetPort: 10001
    targetPort: 10001
strategy:
  type: Recreate


---
---
apiVersion: autoscaling/v2
apiVersion: autoscaling/v2
+46 −5
Original line number Original line Diff line number Diff line
#!/bin/bash
# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# ----- TeraFlowSDN ------------------------------------------------------------
# ----- TeraFlowSDN ------------------------------------------------------------

# Set the URL of the internal MicroK8s Docker registry where the images will be uploaded to.
export TFS_REGISTRY_IMAGES="http://localhost:32000/tfs/"
export TFS_REGISTRY_IMAGES="http://localhost:32000/tfs/"

# Set the list of components, separated by spaces, you want to build images for, and deploy.
export TFS_COMPONENTS="context device automation monitoring pathcomp service slice compute webui load_generator l3_attackmitigator l3_centralizedattackdetector"
export TFS_COMPONENTS="context device automation monitoring pathcomp service slice compute webui load_generator l3_attackmitigator l3_centralizedattackdetector"

# Set the tag you want to use for your images.
export TFS_IMAGE_TAG="dev"
export TFS_IMAGE_TAG="dev"

# Set the name of the Kubernetes namespace to deploy TFS to.
export TFS_K8S_NAMESPACE="tfs"
export TFS_K8S_NAMESPACE="tfs"

# Set additional manifest files to be applied after the deployment
export TFS_EXTRA_MANIFESTS="manifests/nginx_ingress_http.yaml"
export TFS_EXTRA_MANIFESTS="manifests/nginx_ingress_http.yaml"

# Set the new Grafana admin password
export TFS_GRAFANA_PASSWORD="admin123+"
export TFS_GRAFANA_PASSWORD="admin123+"

# Disable skip-build flag to rebuild the Docker images.
export TFS_SKIP_BUILD=""
export TFS_SKIP_BUILD=""



# ----- CockroachDB ------------------------------------------------------------
# ----- CockroachDB ------------------------------------------------------------

# Set the namespace where CockroachDB will be deployed.
export CRDB_NAMESPACE="crdb"
export CRDB_NAMESPACE="crdb"


# Set the external port CockroachDB PostgreSQL interface will be exposed to.
# Set the external port CockroachDB PostgreSQL interface will be exposed to.
@@ -18,15 +51,27 @@ export CRDB_EXT_PORT_HTTP="8081"


# Set the database username to be used by Context.
# Set the database username to be used by Context.
export CRDB_USERNAME="tfs"
export CRDB_USERNAME="tfs"

# Set the database user's password to be used by Context.
export CRDB_PASSWORD="tfs123"
export CRDB_PASSWORD="tfs123"

# Set the database name to be used by Context.
export CRDB_DATABASE="tfs"
export CRDB_DATABASE="tfs"

# Set CockroachDB installation mode to 'single'. This option is convenient for development and testing.
# See ./deploy/all.sh or ./deploy/crdb.sh for additional details
export CRDB_DEPLOY_MODE="single"
export CRDB_DEPLOY_MODE="single"


# Disable flag for dropping database, if it exists.
# Disable flag for dropping database, if it exists.
export CRDB_DROP_DATABASE_IF_EXISTS=""
export CRDB_DROP_DATABASE_IF_EXISTS=""

# Disable flag for re-deploying CockroachDB from scratch.
export CRDB_REDEPLOY=""
export CRDB_REDEPLOY=""



# ----- NATS -------------------------------------------------------------------
# ----- NATS -------------------------------------------------------------------

# Set the namespace where NATS will be deployed.
export NATS_NAMESPACE="nats"
export NATS_NAMESPACE="nats"


# Set the external port NATS Client interface will be exposed to.
# Set the external port NATS Client interface will be exposed to.
@@ -38,6 +83,7 @@ export NATS_EXT_PORT_HTTP="8222"
# Disable flag for re-deploying NATS from scratch.
# Disable flag for re-deploying NATS from scratch.
export NATS_REDEPLOY=""
export NATS_REDEPLOY=""



# ----- QuestDB ----------------------------------------------------------------
# ----- QuestDB ----------------------------------------------------------------


# Set the namespace where QuestDB will be deployed.
# Set the namespace where QuestDB will be deployed.
@@ -69,8 +115,3 @@ export QDB_DROP_TABLES_IF_EXIST=""


# Disable flag for re-deploying QuestDB from scratch.
# Disable flag for re-deploying QuestDB from scratch.
export QDB_REDEPLOY=""
export QDB_REDEPLOY=""

export CRDB_DROP_DATABASE_IF_EXISTS="YES"
export CRDB_REDEPLOY="YES"
export NATS_REDEPLOY="YES"
export QDB_REDEPLOY="TRUE"
 No newline at end of file
+8 −0
Original line number Original line Diff line number Diff line
@@ -30,8 +30,12 @@ from common.tools.grpc.Tools import grpc_message_to_json_string
from context.client.ContextClient import ContextClient
from context.client.ContextClient import ContextClient
from service.client.ServiceClient import ServiceClient
from service.client.ServiceClient import ServiceClient


from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method

LOGGER = logging.getLogger(__name__)
LOGGER = logging.getLogger(__name__)


METRICS_POOL = MetricsPool('l3_attackmitigator', 'RPC')



class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
    def __init__(self):
    def __init__(self):
@@ -123,6 +127,8 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):
        if service_reply != service_request.service_id:  # pylint: disable=no-member
        if service_reply != service_request.service_id:  # pylint: disable=no-member
            raise Exception("Service update failed. Wrong ServiceId was returned")
            raise Exception("Service update failed. Wrong ServiceId was returned")



    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def PerformMitigation(self, request, context):
    def PerformMitigation(self, request, context):
        last_value = request.confidence
        last_value = request.confidence
        last_tag = request.tag
        last_tag = request.tag
@@ -180,6 +186,8 @@ class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer):


        return Empty(message=f"OK, received values: {last_tag} with confidence {last_value}.")
        return Empty(message=f"OK, received values: {last_tag} with confidence {last_value}.")



    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def GetConfiguredACLRules(self, request, context):
    def GetConfiguredACLRules(self, request, context):
        acl_rules = ACLRules()
        acl_rules = ACLRules()


+239 −131
Original line number Original line Diff line number Diff line
@@ -40,14 +40,20 @@ from l3_attackmitigator.client.l3_attackmitigatorClient import l3_attackmitigato


import uuid
import uuid


from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method



LOGGER = logging.getLogger(__name__)
LOGGER = logging.getLogger(__name__)
current_dir = os.path.dirname(os.path.abspath(__file__))
current_dir = os.path.dirname(os.path.abspath(__file__))


# Demo constants
# Demo constants
DEMO_MODE = True
DEMO_MODE = False
ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]
ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]


BATCH_SIZE= 10

METRICS_POOL = MetricsPool('l3_centralizedattackdetector', 'RPC')



class ConnectionInfo:
class ConnectionInfo:
    def __init__(self, ip_o, port_o, ip_d, port_d):
    def __init__(self, ip_o, port_o, ip_d, port_d):
@@ -94,7 +100,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
        self.cryptomining_detector_features_metadata.sort()
        self.cryptomining_detector_features_metadata.sort()
        LOGGER.info("Cryptomining Detector Features: " + str(self.cryptomining_detector_features_metadata))
        LOGGER.info("Cryptomining Detector Features: " + str(self.cryptomining_detector_features_metadata))
        
        
        LOGGER.info("CHANGE CHECK 3")
        LOGGER.info("Batch size: " + BATCH_SIZE)


        self.input_name = self.cryptomining_detector_model.get_inputs()[0].name
        self.input_name = self.cryptomining_detector_model.get_inputs()[0].name
        self.label_name = self.cryptomining_detector_model.get_outputs()[0].name
        self.label_name = self.cryptomining_detector_model.get_outputs()[0].name
@@ -118,6 +124,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto


        self.l3_non_empty_time_interval = False
        self.l3_non_empty_time_interval = False
        
        
        self.active_requests = []

        self.monitoring_client = MonitoringClient()
        self.monitoring_client = MonitoringClient()
        self.service_ids = []
        self.service_ids = []
        self.monitored_kpis = {
        self.monitored_kpis = {
@@ -452,19 +460,110 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto


        return output_message
        return output_message
    
    
    """
    Classify a batch of connections as standard traffic or cryptomining attack and return the results
        -input: 
            + requests: list of L3CentralizedattackdetectorMetrics objects with connection features information
        -output: list of dictionaries (one per request) with information about the assigned class and prediction confidence
    """

    def perform_distributed_inference(self, requests):
        """
        Classify a batch of connections as standard traffic or cryptomining attack.

        All requests are stacked into a single input matrix and sent to the
        cryptomining detector model in one call, so the cost of invoking the
        model is paid once per batch instead of once per connection.

            -input:
                + requests: sequence of request objects; each one exposes
                  `features` (items with a `feature` value) and
                  `connection_metadata`
            -output: list with one dictionary per request (same order as
              `requests`) holding the assigned class, the prediction confidence
              and the connection metadata. An empty batch yields an empty list.
        """
        # Nothing to classify: return early instead of failing on requests[0].
        if not requests:
            return []

        batch_size = len(requests)

        # One row per request, one column per feature. The model is fed float32.
        x_data = np.asarray(
            [[feature.feature for feature in request.features] for request in requests],
            dtype=np.float32,
        )

        LOGGER.debug("x_data.shape: %s", x_data.shape)

        inference_time_start = time.perf_counter()

        # Perform inference on the whole batch at once
        predictions = self.cryptomining_detector_model.run([self.prob_name], {self.input_name: x_data})[0]

        inference_time_end = time.perf_counter()

        # Measure inference time
        self.cad_inference_times.append(inference_time_end - inference_time_start)

        # Once enough measurements have been collected, dump them (and their
        # summary statistics) to disk, tagged with the batch size.
        if len(self.cad_inference_times) > self.cad_num_inference_measurements:
            inference_times_np_array = np.array(self.cad_inference_times)
            np.save(f"inference_times_{batch_size}.npy", inference_times_np_array)

            inference_time_stats = (
                ("Average", np.mean(inference_times_np_array)),
                ("Max", np.max(inference_times_np_array)),
                ("Min", np.min(inference_times_np_array)),
                ("Standard deviation", np.std(inference_times_np_array)),
                ("Median", np.median(inference_times_np_array)),
            )

            for stat_name, stat_value in inference_time_stats:
                LOGGER.debug("%s inference time: %s", stat_name, stat_value)

            with open(f"inference_times_stats_{batch_size}.txt", "w") as f:
                f.writelines(
                    f"{stat_name} inference time: {stat_value}\n"
                    for stat_name, stat_value in inference_time_stats
                )

        # Gather the predicted class, the probability of that class and other relevant information required to block the attack
        output_messages = []
        for i, request in enumerate(requests):
            class_probabilities = predictions[i]
            metadata = request.connection_metadata

            # Index 1 is the probability of the cryptomining class, index 0 of the normal class
            if class_probabilities[1] >= self.CLASSIFICATION_THRESHOLD:
                confidence = class_probabilities[1]
                tag_name = "Crypto"
                tag = self.CRYPTO_CLASS
            else:
                confidence = class_probabilities[0]
                tag_name = "Normal"
                tag = self.NORMAL_CLASS

            output_messages.append({
                "confidence": confidence,
                "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
                "ip_o": metadata.ip_o,
                "ip_d": metadata.ip_d,
                "tag_name": tag_name,
                "tag": tag,
                "flow_id": metadata.flow_id,
                "protocol": metadata.protocol,
                "port_o": metadata.port_o,
                "port_d": metadata.port_d,
                "ml_id": self.cryptomining_detector_file_name,
                "service_id": metadata.service_id,
                "endpoint_id": metadata.endpoint_id,
                "time_start": metadata.time_start,
                "time_end": metadata.time_end,
            })

        return output_messages

    """
    """
    Receive features from Attack Mitigator, predict attack and communicate with Attack Mitigator
    Receive features from Attack Mitigator, predict attack and communicate with Attack Mitigator
        -input: 
        -input: 
            + request: L3CentralizedattackdetectorMetrics object with connection features information
            + request: L3CentralizedattackdetectorMetrics object with connection features information
        -output: Empty object with a message about the execution of the function
        -output: Empty object with a message about the execution of the function
    """
    """

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def AnalyzeConnectionStatistics(self, request, context):
    def AnalyzeConnectionStatistics(self, request, context):
        # Perform inference with the data sent in the request
        # Perform inference with the data sent in the request
        self.active_requests.append(request)
        
        if len(self.active_requests) == BATCH_SIZE:
            logging.info("Performing inference...")
            logging.info("Performing inference...")
            
            
            inference_time_start = time.time()
            inference_time_start = time.time()
        cryptomining_detector_output = self.perform_inference(request)
            cryptomining_detector_output = self.perform_distributed_inference(self.active_requests)
            inference_time_end = time.time()
            inference_time_end = time.time()
            
            
            LOGGER.debug("Inference performed in {} seconds".format(inference_time_end - inference_time_start))
            LOGGER.debug("Inference performed in {} seconds".format(inference_time_end - inference_time_start))
@@ -473,9 +572,10 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            self.inference_results.append({"output": cryptomining_detector_output, "timestamp": datetime.now()})
            self.inference_results.append({"output": cryptomining_detector_output, "timestamp": datetime.now()})
            LOGGER.debug("inference_results length: {}".format(len(self.inference_results)))
            LOGGER.debug("inference_results length: {}".format(len(self.inference_results)))


        service_id = request.connection_metadata.service_id
            for i, req in enumerate(self.active_requests):
        device_id = request.connection_metadata.endpoint_id.device_id
                service_id = req.connection_metadata.service_id
        endpoint_id = request.connection_metadata.endpoint_id
                device_id = req.connection_metadata.endpoint_id.device_id
                endpoint_id = req.connection_metadata.endpoint_id


                # Check if a request of a new service has been received and, if so, create the monitored KPIs for that service
                # Check if a request of a new service has been received and, if so, create the monitored KPIs for that service
                if service_id not in self.service_ids:
                if service_id not in self.service_ids:
@@ -487,27 +587,27 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                monitor_kpis_end = time.time()
                monitor_kpis_end = time.time()


                LOGGER.debug("Monitoring KPIs performed in {} seconds".format(monitor_kpis_end - monitor_kpis_start))
                LOGGER.debug("Monitoring KPIs performed in {} seconds".format(monitor_kpis_end - monitor_kpis_start))
        LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output))
                LOGGER.debug("cryptomining_detector_output: {}".format(cryptomining_detector_output[i]))


                if DEMO_MODE:
                if DEMO_MODE:
            self.analyze_prediction_accuracy(cryptomining_detector_output["confidence"])
                    self.analyze_prediction_accuracy(cryptomining_detector_output[i]["confidence"])


                connection_info = ConnectionInfo(
                connection_info = ConnectionInfo(
            request.connection_metadata.ip_o,
                    req.connection_metadata.ip_o,
            request.connection_metadata.port_o,
                    req.connection_metadata.port_o,
            request.connection_metadata.ip_d,
                    req.connection_metadata.ip_d,
            request.connection_metadata.port_d,
                    req.connection_metadata.port_d,
                )
                )


                self.l3_non_empty_time_interval = True
                self.l3_non_empty_time_interval = True


        if cryptomining_detector_output["tag_name"] == "Crypto":
                if cryptomining_detector_output[i]["tag_name"] == "Crypto":
                    self.l3_security_status = 1
                    self.l3_security_status = 1


                    self.l3_inferences_in_interval_counter_crypto += 1
                    self.l3_inferences_in_interval_counter_crypto += 1
                    self.l3_ml_model_confidence_crypto = (
                    self.l3_ml_model_confidence_crypto = (
                        self.l3_ml_model_confidence_crypto * (self.l3_inferences_in_interval_counter_crypto - 1)
                        self.l3_ml_model_confidence_crypto * (self.l3_inferences_in_interval_counter_crypto - 1)
                + cryptomining_detector_output["confidence"]
                        + cryptomining_detector_output[i]["confidence"]
                    ) / self.l3_inferences_in_interval_counter_crypto
                    ) / self.l3_inferences_in_interval_counter_crypto


                    if connection_info not in self.l3_attacks:
                    if connection_info not in self.l3_attacks:
@@ -521,11 +621,12 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                    self.l3_inferences_in_interval_counter_normal += 1
                    self.l3_inferences_in_interval_counter_normal += 1
                    self.l3_ml_model_confidence_normal = (
                    self.l3_ml_model_confidence_normal = (
                        self.l3_ml_model_confidence_normal * (self.l3_inferences_in_interval_counter_normal - 1)
                        self.l3_ml_model_confidence_normal * (self.l3_inferences_in_interval_counter_normal - 1)
                + cryptomining_detector_output["confidence"]
                        + cryptomining_detector_output[i]["confidence"]
                    ) / self.l3_inferences_in_interval_counter_normal
                    ) / self.l3_inferences_in_interval_counter_normal


                # Only notify Attack Mitigator when a cryptomining connection has been detected
                # Only notify Attack Mitigator when a cryptomining connection has been detected
        if cryptomining_detector_output["tag_name"] == "Crypto" and connection_info not in self.attack_connections:
                if cryptomining_detector_output[i]["tag_name"] == "Crypto":
                    if DEMO_MODE:
                        self.attack_connections.append(connection_info)
                        self.attack_connections.append(connection_info)


                    if connection_info.ip_o in ATTACK_IPS or connection_info.ip_d in ATTACK_IPS:
                    if connection_info.ip_o in ATTACK_IPS or connection_info.ip_d in ATTACK_IPS:
@@ -549,7 +650,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto


                    try:
                    try:
                        logging.info("Sending the connection information to the Attack Mitigator component...")
                        logging.info("Sending the connection information to the Attack Mitigator component...")
                message = L3AttackmitigatorOutput(**cryptomining_detector_output)
                        message = L3AttackmitigatorOutput(**cryptomining_detector_output[i])
                        response = self.attackmitigator_client.PerformMitigation(message)
                        response = self.attackmitigator_client.PerformMitigation(message)
                        notification_time_end = time.perf_counter()
                        notification_time_end = time.perf_counter()


@@ -584,7 +685,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                        # logging.info("Attack Mitigator notified and received response: ", response.message)  # FIX No message received
                        # logging.info("Attack Mitigator notified and received response: ", response.message)  # FIX No message received
                        logging.info("Attack Mitigator notified")
                        logging.info("Attack Mitigator notified")


                return Empty(message="OK, information received and mitigator notified abou the attack")
                        #return Empty(message="OK, information received and mitigator notified abou the attack")
                    
                    except Exception as e:
                    except Exception as e:
                        logging.error("Error notifying the Attack Mitigator component about the attack: ", e)
                        logging.error("Error notifying the Attack Mitigator component about the attack: ", e)
                        logging.error("Couldn't find l3_attackmitigator")
                        logging.error("Couldn't find l3_attackmitigator")
@@ -593,7 +695,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                else:
                else:
                    logging.info("No attack detected")
                    logging.info("No attack detected")


            if cryptomining_detector_output["tag_name"] != "Crypto":
                    if cryptomining_detector_output[i]["tag_name"] != "Crypto":
                        if connection_info.ip_o not in ATTACK_IPS and connection_info.ip_d not in ATTACK_IPS:
                        if connection_info.ip_o not in ATTACK_IPS and connection_info.ip_d not in ATTACK_IPS:
                            self.correct_predictions += 1
                            self.correct_predictions += 1
                        else:
                        else:
@@ -602,7 +704,12 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto


                        self.total_predictions += 1
                        self.total_predictions += 1


            return Empty(message="Ok, information received (no attack detected)")
                    # return Empty(message="Ok, information received (no attack detected)")
            
            self.active_requests = []
            return Empty(message="Ok, metrics processed")
            
        return Empty(message="Ok, information received")


    def analyze_prediction_accuracy(self, confidence):
    def analyze_prediction_accuracy(self, confidence):
        LOGGER.info("Number of Attack Connections Correctly Classified: {}".format(self.correct_attack_conns))
        LOGGER.info("Number of Attack Connections Correctly Classified: {}".format(self.correct_attack_conns))
@@ -637,6 +744,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            f.write("Timestamp: {}\n".format(datetime.now().strftime("%d/%m/%Y %H:%M:%S")))
            f.write("Timestamp: {}\n".format(datetime.now().strftime("%d/%m/%Y %H:%M:%S")))
            f.close()
            f.close()


    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def AnalyzeBatchConnectionStatistics(self, request, context):
    def AnalyzeBatchConnectionStatistics(self, request, context):
        batch_time_start = time.time()
        batch_time_start = time.time()


@@ -658,7 +766,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
    Send features allocated in the metadata of the onnx file to the DAD
    Send features allocated in the metadata of the onnx file to the DAD
        -output: ONNX metadata as a list of integers
        -output: ONNX metadata as a list of integers
    """
    """

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def GetFeaturesIds(self, request: Empty, context):
    def GetFeaturesIds(self, request: Empty, context):
        features = AutoFeatures()
        features = AutoFeatures()