Skip to content
Snippets Groups Projects
l3_centralizedattackdetectorServiceServicerImpl.py 4.44 KiB
Newer Older
# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ldemarcosm's avatar
ldemarcosm committed
from __future__ import print_function

import logging
import os
from datetime import datetime

import grpc
import numpy as np
import onnxruntime as rt

from common.proto.l3_attackmitigator_pb2 import (
    L3AttackmitigatorOutput,
)
from common.proto.l3_attackmitigator_pb2_grpc import (
    L3AttackmitigatorStub,
)
from common.proto.l3_centralizedattackdetector_pb2 import (
    Empty,
)
from common.proto.l3_centralizedattackdetector_pb2_grpc import (
    L3CentralizedattackdetectorServicer,
)

LOGGER = logging.getLogger(__name__)
here = os.path.dirname(os.path.abspath(__file__))
MODEL_FILE = os.path.join(here, "ml_model/teraflow_rf.onnx")
ldemarcosm's avatar
ldemarcosm committed

class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer):
    """gRPC servicer that runs an ONNX model over per-flow statistics.

    Every ``SendInput`` request is stored, passed through the model loaded
    from ``MODEL_FILE`` and the resulting ``L3AttackmitigatorOutput`` is
    forwarded to the attack mitigator service over an insecure gRPC channel.
    """

    # Mitigator endpoint used when no address is passed to ``__init__``.
    DEFAULT_MITIGATOR_ADDRESS = "localhost:10002"

    # A flow is tagged "Crypto" when the model score at index 1 reaches this value.
    ATTACK_THRESHOLD = 0.5

    def __init__(self, mitigator_address=DEFAULT_MITIGATOR_ADDRESS):
        """Load the ONNX model and cache the names of its input and outputs.

        Args:
            mitigator_address: ``host:port`` target of the attack mitigator
                gRPC service. Defaults to ``DEFAULT_MITIGATOR_ADDRESS``.
        """
        LOGGER.debug("Creating Servicer...")
        self.mitigator_address = mitigator_address
        # Every request received by SendInput is appended here and never
        # removed, so this list grows for the lifetime of the servicer.
        self.inference_values = []
        self.model = rt.InferenceSession(MODEL_FILE)
        self.input_name = self.model.get_inputs()[0].name
        self.label_name = self.model.get_outputs()[0].name
        self.prob_name = self.model.get_outputs()[1].name

    def make_inference(self, request):
        """Run the model on one request and build the mitigator message.

        Args:
            request: message exposing the eight ``n_*`` statistics used as
                model features plus ``ip_o``, ``flow_id``, ``protocol``,
                ``port_d``, ``time_start`` and ``time_end``.

        Returns:
            An ``L3AttackmitigatorOutput`` whose ``tag``/``tag_name`` are
            ``1``/``"Crypto"`` when the score at index 1 is at least
            ``ATTACK_THRESHOLD`` and ``0``/``"Normal"`` otherwise;
            ``confidence`` holds the score of the chosen tag.
        """
        # ML MODEL: a single-row batch holding the eight features in the
        # order the model is fed.
        x_data = np.array([
            [
                request.n_packets_server_seconds,
                request.n_packets_client_seconds,
                request.n_bits_server_seconds,
                request.n_bits_client_seconds,
                request.n_bits_server_n_packets_server,
                request.n_bits_client_n_packets_client,
                request.n_packets_server_n_packets_client,
                request.n_bits_server_n_bits_client,
            ]
        ])

        # Only the second model output (``prob_name``) is requested;
        # presumably it holds the per-class scores -- TODO confirm.
        predictions = self.model.run(
            [self.prob_name], {self.input_name: x_data.astype(np.float32)})[0]
        scores = predictions[0]

        if scores[1] >= self.ATTACK_THRESHOLD:
            confidence, tag_name, tag = scores[1], "Crypto", 1
        else:
            confidence, tag_name, tag = scores[0], "Normal", 0

        # Output format
        output_message = {
            "confidence": confidence,
            "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
            "ip_o": request.ip_o,
            "tag_name": tag_name,
            "tag": tag,
            "flow_id": request.flow_id,
            "protocol": request.protocol,
            "port_d": request.port_d,
            "ml_id": "RandomForest",
            "time_start": request.time_start,
            "time_end": request.time_end,
        }
        return L3AttackmitigatorOutput(**output_message)

    def SendInput(self, request, context):
        """Store the request, classify it and notify the attack mitigator.

        Args:
            request: flow statistics message accepted by ``make_inference``.
            context: gRPC servicer context (unused).

        Returns:
            ``Empty`` whose ``message`` reports whether the mitigator was
            notified. A failure while contacting the mitigator is logged and
            reported in the message instead of being raised.
        """
        # PERFORM INFERENCE WITH SENT INPUTS
        LOGGER.info("Inferencing ...")

        # STORE VALUES
        self.inference_values.append(request)

        # MAKE INFERENCE
        output = self.make_inference(request)

        # SEND INFO TO MITIGATION SERVER
        try:
            with grpc.insecure_channel(self.mitigator_address) as channel:
                stub = L3AttackmitigatorStub(channel)
                LOGGER.info("Sending to mitigator...")
                response = stub.SendOutput(output)
                LOGGER.info(
                    "Sent output to mitigator and received: %s", response.message)
        except Exception:
            # Notification stays best-effort, but KeyboardInterrupt/SystemExit
            # are no longer swallowed and the cause is logged with traceback.
            LOGGER.exception("Couldnt find l3_attackmitigator")
            return Empty(
                message="Mitigator Not found"
            )

        # RETURN "OK" TO THE CALLER
        return Empty(
            message="OK, information received and mitigator notified"
        )

    def GetOutput(self, request, context):
        """Run ``make_inference`` on a value derived from the stored requests.

        Args:
            request: incoming message (unused).
            context: gRPC servicer context (unused).

        Returns:
            Whatever ``make_inference`` returns for the derived value.
        """
        LOGGER.info("Returing inference output...")
        # NOTE(review): ``inference_values`` holds the request messages stored
        # by SendInput, yet they are doubled and summed here and the resulting
        # scalar is handed to ``make_inference``, which reads message
        # attributes from its argument. This looks like it cannot succeed as
        # written; the intended behaviour is not visible in this file, so the
        # logic is left unchanged -- confirm against the service definition.
        k = np.multiply(self.inference_values, [2])
        k = np.sum(k)
        return self.make_inference(k)