# l3_distributedattackdetector.py
# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from sys import stdout
import os
import time
import signal
import grpc
import numpy as np
import asyncio
from common.proto import L3CentralizedattackdetectorStub
from common.proto.l3_centralizedattackdetector_pb2 import (
    L3CentralizedattackdetectorMetrics,
    L3CentralizedattackdetectorBatchInput,
    ConnectionMetadata,
    Feature,
    Empty,
)

from common.proto.context_pb2 import (
    ServiceTypeEnum,
    ContextId,
)

from common.proto.context_pb2_grpc import ContextServiceStub

#  Setup LOGGER: plain stdout handler, "LEVEL    message" format
LOGGER = logging.getLogger("dad_LOGGER")
LOGGER.setLevel(logging.INFO)
logFormatter = logging.Formatter(fmt="%(levelname)-8s %(message)s")
consoleHandler = logging.StreamHandler(stdout)
consoleHandler.setFormatter(logFormatter)
LOGGER.addHandler(consoleHandler)

# Directory (relative to this file) where the Tstat output folders are expected
TSTAT_DIR_NAME = "piped/"
# gRPC endpoint (host:port) of the L3 Centralized Attack Detector service
CENTRALIZED_ATTACK_DETECTOR = "192.168.165.78:10001"
# Template for per-connection metadata; copied for each newly seen connection
JSON_BLANK = {
    "ip_o": "",  # Client IP
    "port_o": "",  # Client port
    "ip_d": "",  # Server ip
    "port_d": "",  # Server port
    "flow_id": "",  # Identifier:c_ip,c_port,s_ip,s_port,time_start
    "protocol": "",  # Connection protocol
    "time_start": 0.0,  # Start of connection
    "time_end": 0.0,  # Time of last packet
}

# Global stop flag; set by the SIGINT handler and polled by the processing loop
STOP = False
# Skip the first line of the Tstat log (presumably a header row — TODO confirm)
IGNORE_FIRST_LINE_TSTAT = True

# TeraFlow context the monitored services belong to
CONTEXT_ID = "admin"
# gRPC endpoint (host:port) of the Context service
CONTEXT_CHANNEL = "192.168.165.78:1010"
# NOTE(review): PROFILING is never read in this file — confirm before removing
PROFILING = False
# When True, accumulate BATCH_SIZE metrics and send them in one batched RPC
SEND_DATA_IN_BATCHES = False
BATCH_SIZE = 10
# IPs treated as known attackers (presumably demo ground truth — TODO confirm)
ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]

class l3_distributedattackdetector():
    """Distributed Attack Detector (DAD).

    Tails a Tstat TCP log file, extracts the feature columns requested by the
    L3 Centralized Attack Detector (CAD), and streams one metrics message per
    log line (or batches of BATCH_SIZE) to the CAD over gRPC until STOP is set.
    """

    def __init__(self):
        LOGGER.info("Creating Distributed Attack Detector")

        self.feature_ids = []  # 1-based Tstat column ids requested by the CAD

        self.cad_features = {}  # last features/metadata message built for the CAD
        self.conn_id = ()  # (client_ip, client_port, server_ip, server_port)

        self.connections_dict = {} # Dict for storing ALL data
        self.new_connections = {} # Dict for storing NEW data

        signal.signal(signal.SIGINT, self.handler)

        with grpc.insecure_channel(CENTRALIZED_ATTACK_DETECTOR) as channel:
            self.cad = L3CentralizedattackdetectorStub(channel)
            LOGGER.info("Connected to the centralized attack detector")

            LOGGER.info("Obtaining features...")
            self.feature_ids = self.get_features_ids()
            LOGGER.info("Features Ids.: {0}".format(self.feature_ids))

            # Blocks here until STOP is requested; the channel stays open
            asyncio.run(self.process_traffic())

    def handler(self, signum=None, frame=None):
        """SIGINT handler: first signal requests a graceful stop, a second one exits.

        BUGFIX: signal.signal() invokes handlers as handler(signum, frame); the
        original signature took no extra arguments, so Ctrl+C raised TypeError.
        BUGFIX: STOP was assigned without `global`, which (a) raised
        UnboundLocalError at `if STOP:` and (b) never propagated the flag to
        process_traffic().
        """
        global STOP
        if STOP:
            exit()

        STOP = True

        LOGGER.info("Gracefully stopping...")

    def follow(self, thefile, time_sleep):
        """
        Generator function that yields new lines in a file
        It reads the logfile (the opened file), sleeping `time_sleep` seconds
        whenever no new data is available, and reassembling partial lines that
        were written without a trailing newline.
        """
        # seek the end of the file
        # thefile.seek(0, os.SEEK_END)

        # Buffer for an incomplete (not yet newline-terminated) line fragment
        trozo = ""

        # start infinite loop
        while True:
            # read last line of file
            line = thefile.readline()

            # sleep if file hasn't been updated
            if not line:
                time.sleep(time_sleep)
                continue
            if line[-1] != "\n":
                # Partial write: keep the fragment and wait for the rest
                trozo += line
            else:
                if trozo != "":
                    line = trozo + line
                    trozo = ""
                yield line

    def load_file(self, dirname=None):  # - Client side -
        """Block until a Tstat output directory exists, then return the path of
        the newest directory's `log_tcp_temp_complete` file.

        :param dirname: directory (relative to this file) holding the Tstat
            output folders; defaults to TSTAT_DIR_NAME, resolved at call time
            (a definition-time default would freeze the module constant and be
            evaluated while the class body executes).
        """
        if dirname is None:
            dirname = TSTAT_DIR_NAME
        while True:
            here = os.path.dirname(os.path.abspath(__file__))
            tstat_piped = os.path.join(here, dirname)
            tstat_dirs = os.listdir(tstat_piped)
            if len(tstat_dirs) > 0:
                # Lexicographic sort: the last entry is the newest Tstat folder
                tstat_dirs.sort()
                new_dir = tstat_dirs[-1]
                tstat_file = os.path.join(tstat_piped, new_dir, "log_tcp_temp_complete")
                LOGGER.info("Following: {0}".format(tstat_file))
                return tstat_file
            else:
                LOGGER.info("No Tstat directory!")
                time.sleep(5)

    def process_line(self, line):
        """
        - Preprocessing before a message per line
        - Returns the selected feature columns of `line` as a list of floats

        BUGFIX: `values` is initialised before the try block; previously, an
        IndexError raised while reading the first feature left `values`
        unbound and `return values` crashed with UnboundLocalError.  On
        IndexError the (possibly partial) list gathered so far is returned.
        """
        fields = line.split(" ")

        values = []
        try:
            for feature_id in self.feature_ids:
                # Tstat feature ids are 1-based; list indices are 0-based
                feature = int(feature_id) - 1
                values.append(float(fields[feature]))
        except IndexError:
            # Log through the module LOGGER instead of a bare print()
            LOGGER.error("IndexError: {0}".format(fields))

        return values

    def get_service_ids(self, context_id_str):
        """Return the ServiceIdList of the given context from the Context service."""
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
            stub = ContextServiceStub(channel)
            context_id = ContextId()
            context_id.context_uuid.uuid = context_id_str
            return stub.ListServiceIds(context_id)

    def get_services(self, context_id_str):
        """Return the ServiceList of the given context from the Context service."""
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
            stub = ContextServiceStub(channel)
            context_id = ContextId()
            context_id.context_uuid.uuid = context_id_str
            return stub.ListServices(context_id)

    def get_service_id(self, context_id):
        """Return the ServiceId matching a hard-coded UUID, or None if absent."""
        service_id_list = self.get_service_ids(context_id)
        service_id = None
        for s_id in service_id_list.service_ids:
            if (
                s_id.service_uuid.uuid == "0eaa0752-c7b6-4c2e-97da-317fbfee5112"
            ):  # TODO: Change this identifier to the L3VPN service identifier with the real router for the demo v2
                service_id = s_id
                break

        return service_id

    def get_service_id2(self, context_id):
        """Return the ServiceId of the first L3NM service in the context, or None."""
        service_list = self.get_services(context_id)
        service_id = None
        for s in service_list.services:
            if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM:
                service_id = s.service_id
                break
        return service_id

    def get_endpoint_id(self, context_id):
        """Return the first endpoint id of the first L3NM service, or None."""
        service_list = self.get_services(context_id)
        endpoint_id = None
        for s in service_list.services:
            if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM:
                endpoint_id = s.service_endpoint_ids[0]
                break
        return endpoint_id

    def get_features_ids(self):
        """Ask the CAD which Tstat feature columns it needs (auto_features)."""
        return self.cad.GetFeaturesIds(Empty()).auto_features

    def check_types(self):
        """Sanity-check the types of the message built by create_cad_features().

        NOTE: these asserts are stripped when Python runs with -O; they are a
        development-time guard, not runtime validation.
        """
        for feature in self.cad_features["features"]:
            assert isinstance(feature, float)

        assert isinstance(self.cad_features["connection_metadata"]["ip_o"], str)
        assert isinstance(self.cad_features["connection_metadata"]["port_o"], str)
        assert isinstance(self.cad_features["connection_metadata"]["ip_d"], str)
        assert isinstance(self.cad_features["connection_metadata"]["port_d"], str)
        assert isinstance(self.cad_features["connection_metadata"]["flow_id"], str)
        assert isinstance(self.cad_features["connection_metadata"]["protocol"], str)
        assert isinstance(self.cad_features["connection_metadata"]["time_start"], float)
        assert isinstance(self.cad_features["connection_metadata"]["time_end"], float)

    def insert_connection(self):
        """Refresh `time_end` for a known connection, or register a new one.

        New connections get a copy of JSON_BLANK plus the service/endpoint ids
        resolved from the Context service (EAFP: the KeyError on first sight of
        a connection triggers the registration path).
        """
        try:
            self.connections_dict[self.conn_id]["time_end"] = time.time()
        except KeyError:
            self.connections_dict[self.conn_id] = JSON_BLANK.copy()
            self.connections_dict[self.conn_id]["time_start"] = time.time()
            self.connections_dict[self.conn_id]["time_end"] = time.time()
            self.connections_dict[self.conn_id]["ip_o"] = self.conn_id[0]
            self.connections_dict[self.conn_id]["port_o"] = self.conn_id[1]
            self.connections_dict[self.conn_id]["flow_id"] = ":".join(self.conn_id)
            self.connections_dict[self.conn_id]["service_id"] = self.get_service_id2(CONTEXT_ID)
            self.connections_dict[self.conn_id]["endpoint_id"] = self.get_endpoint_id(CONTEXT_ID)
            self.connections_dict[self.conn_id]["protocol"] = "TCP"
            self.connections_dict[self.conn_id]["ip_d"] = self.conn_id[2]
            self.connections_dict[self.conn_id]["port_d"] = self.conn_id[3]

    def check_if_connection_is_attack(self):
        """Log when either endpoint of the current connection is a known attack IP."""
        if self.conn_id[0] in ATTACK_IPS or self.conn_id[2] in ATTACK_IPS:
            LOGGER.info("Attack detected. Origin: {0}, destination: {1}".format(self.conn_id[0], self.conn_id[2]))

    def create_cad_features(self):
        """Assemble the CAD message for the current connection: the first 10
        extracted features plus the stored connection metadata."""
        conn = self.connections_dict[self.conn_id]
        self.cad_features = {
            "features": self.new_connections[self.conn_id][0:10],
            "connection_metadata": {
                "ip_o": conn["ip_o"],
                "port_o": conn["port_o"],
                "ip_d": conn["ip_d"],
                "port_d": conn["port_d"],
                "flow_id": conn["flow_id"],
                "service_id": conn["service_id"],
                "endpoint_id": conn["endpoint_id"],
                "protocol": conn["protocol"],
                "time_start": conn["time_start"],
                "time_end": conn["time_end"],
            },
        }

    async def send_batch_async(self, metrics_list_pb):
        """Send a batch of metrics to the CAD without blocking the event loop
        (the synchronous gRPC call runs in the default executor)."""
        loop = asyncio.get_running_loop()

        # Create metrics batch
        metrics_batch = L3CentralizedattackdetectorBatchInput()
        metrics_batch.metrics.extend(metrics_list_pb)

        # Send batch
        future = loop.run_in_executor(
            None, self.cad.AnalyzeBatchConnectionStatistics, metrics_batch
        )

        try:
            await future
        except Exception as e:
            LOGGER.error(f"Error sending batch: {e}")

    async def send_data(self, metrics_list_pb, send_data_times):
        """Send metrics to the CAD.

        In batch mode, a batch is sent (and the list reset) only when
        BATCH_SIZE messages have accumulated; otherwise the latest message is
        sent immediately.  Returns the (possibly reset) list and the updated
        per-send timing array.
        """
        # Send to CAD
        if SEND_DATA_IN_BATCHES:
            if len(metrics_list_pb) == BATCH_SIZE:
                send_data_time_start = time.time()
                await self.send_batch_async(metrics_list_pb)
                metrics_list_pb = []

                send_data_time_end = time.time()
                send_data_time = send_data_time_end - send_data_time_start
                send_data_times = np.append(send_data_times, send_data_time)

        else:
            send_data_time_start = time.time()
            self.cad.AnalyzeConnectionStatistics(metrics_list_pb[-1])

            send_data_time_end = time.time()
            send_data_time = send_data_time_end - send_data_time_start
            send_data_times = np.append(send_data_times, send_data_time)

        return metrics_list_pb, send_data_times

    async def process_traffic(self):
        """Main loop: follow the Tstat log and forward one metrics message
        (or batch) per log line to the CAD until STOP is set.

        BUGFIX: the log file is now opened in a `with` block so it is closed
        when the loop exits (STOP) or an exception escapes.
        """
        LOGGER.info("Loading Tstat log file...")
        with open(self.load_file(), "r") as logfile:
            LOGGER.info("Following Tstat log file...")
            loglines = self.follow(logfile, 5)

            process_time = []
            num_lines = 0

            send_data_times = np.array([])
            metrics_list_pb = []

            LOGGER.info("Starting to process data...")

            index = 0
            while True:
                line = next(loglines, None)

                # `is None` instead of `== None` (identity test on the sentinel)
                while line is None:
                    LOGGER.info("Waiting for new data...")
                    time.sleep(1 / 100)
                    line = next(loglines, None)
                if index == 0 and IGNORE_FIRST_LINE_TSTAT:
                    index = index + 1
                    continue
                if STOP:
                    break

                num_lines += 1
                start = time.time()
                # Columns 0/1 and 14/15 hold client ip/port and server ip/port
                line_id = line.split(" ")
                self.conn_id = (line_id[0], line_id[1], line_id[14], line_id[15])
                self.new_connections[self.conn_id] = self.process_line(line)

                self.check_if_connection_is_attack()

                self.insert_connection()

                self.create_cad_features()

                self.check_types()

                connection_metadata = ConnectionMetadata(**self.cad_features["connection_metadata"])
                metrics = L3CentralizedattackdetectorMetrics()

                for feature in self.cad_features["features"]:
                    feature_obj = Feature()
                    feature_obj.feature = feature
                    metrics.features.append(feature_obj)

                metrics.connection_metadata.CopyFrom(connection_metadata)
                metrics_list_pb.append(metrics)

                metrics_list_pb, send_data_times = await self.send_data(metrics_list_pb, send_data_times)

                index = index + 1

                process_time.append(time.time() - start)

                if num_lines % 10 == 0:
                    LOGGER.info(f"Number of lines: {num_lines} - Average processing time: {sum(process_time) / num_lines}")