Commit ce468995 authored by Amit Karamchandani Batra's avatar Amit Karamchandani Batra
Browse files

Cleanup of the DAD service implementation

parent 3f081d8d
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
37.187.95.110,91.121.140.167,94.23.23.52,94.23.247.226,149.202.83.171
 No newline at end of file
+53 −60
Original line number Original line Diff line number Diff line
@@ -13,24 +13,22 @@
# limitations under the License.
# limitations under the License.


import asyncio
import asyncio
import grpc
import logging
import logging
import numpy as np
import os
import os
import signal
import signal
import time
import time
from sys import stdout
from sys import stdout
from common.proto.context_pb2 import (

    Empty,
import grpc
    ServiceTypeEnum,
import numpy as np
    ContextId,

)
from common.proto.context_pb2 import ContextId, Empty, ServiceTypeEnum
from common.proto.context_pb2_grpc import ContextServiceStub
from common.proto.context_pb2_grpc import ContextServiceStub
from common.proto.l3_centralizedattackdetector_pb2 import (
from common.proto.l3_centralizedattackdetector_pb2 import (
    L3CentralizedattackdetectorMetrics,
    L3CentralizedattackdetectorBatchInput,
    ConnectionMetadata,
    ConnectionMetadata,
    Feature,
    Feature,
    L3CentralizedattackdetectorBatchInput,
    L3CentralizedattackdetectorMetrics,
)
)
from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorStub
from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorStub


@@ -42,8 +40,12 @@ consoleHandler = logging.StreamHandler(stdout)
consoleHandler.setFormatter(logFormatter)
consoleHandler.setFormatter(logFormatter)
LOGGER.addHandler(consoleHandler)
LOGGER.addHandler(consoleHandler)


# Define constants
TSTAT_DIR_NAME = "piped/"
TSTAT_DIR_NAME = "piped/"
CENTRALIZED_ATTACK_DETECTOR = "192.168.165.78:10001"
CONTROLLER_IP = "192.168.165.78"  # Change this to the IP of the controller
CONTEXT_ID = "admin"  # Change this to the context ID to be used
CONTEXT_CHANNEL = f"{CONTROLLER_IP}:1010"  # Change this to the IP of the controller
CENTRALIZED_ATTACK_DETECTOR = f"{CONTROLLER_IP}:10001"
JSON_BLANK = {
JSON_BLANK = {
    "ip_o": "",  # Client IP
    "ip_o": "",  # Client IP
    "port_o": "",  # Client port
    "port_o": "",  # Client port
@@ -54,18 +56,14 @@ JSON_BLANK = {
    "time_start": 0.0,  # Start of connection
    "time_start": 0.0,  # Start of connection
    "time_end": 0.0,  # Time of last packet
    "time_end": 0.0,  # Time of last packet
}
}

STOP = False
STOP = False
IGNORE_FIRST_LINE_TSTAT = True
IGNORE_FIRST_LINE_TSTAT = True

CONTEXT_ID = "admin"
CONTEXT_CHANNEL = "192.168.165.78:1010"
PROFILING = False
PROFILING = False
SEND_DATA_IN_BATCHES = False
SEND_DATA_IN_BATCHES = False
BATCH_SIZE = 10
BATCH_SIZE = 10
ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]


class l3_distributedattackdetector():

class l3_distributedattackdetector:
    def __init__(self):
    def __init__(self):
        LOGGER.info("Creating Distributed Attack Detector")
        LOGGER.info("Creating Distributed Attack Detector")
        
        
@@ -77,6 +75,8 @@ class l3_distributedattackdetector():
        self.connections_dict = {}  # Dict for storing ALL data
        self.connections_dict = {}  # Dict for storing ALL data
        self.new_connections = {}  # Dict for storing NEW data
        self.new_connections = {}  # Dict for storing NEW data


        self.known_attack_ips = self.read_kwnown_attack_ips()
        
        signal.signal(signal.SIGINT, self.handler)
        signal.signal(signal.SIGINT, self.handler)


        with grpc.insecure_channel(CENTRALIZED_ATTACK_DETECTOR) as channel:
        with grpc.insecure_channel(CENTRALIZED_ATTACK_DETECTOR) as channel:
@@ -89,6 +89,14 @@ class l3_distributedattackdetector():


            asyncio.run(self.process_traffic())
            asyncio.run(self.process_traffic())
    
    
    def read_kwnown_attack_ips(self):
        known_attack_ips = []
        
        # open known attack ips csv file
        with open("known_attack_ips.csv", "r") as f:
            known_attack_ips = f.read().splitlines()
        
        return known_attack_ips


    def handler(self):
    def handler(self):
        if STOP:
        if STOP:
@@ -98,33 +106,34 @@ class l3_distributedattackdetector():


        LOGGER.info("Gracefully stopping...")
        LOGGER.info("Gracefully stopping...")


    def follow(self, thefile, time_sleep):
    def follow(self, file, time_sleep):
        """
        """
        Generator function that yields new lines in a file
        Generator function that yields new lines in a file
        It reads the logfie (the opened file)
        """
        """

        # seek the end of the file
        # seek the end of the file
        # thefile.seek(0, os.SEEK_END)
        # file.seek(0, os.SEEK_END)


        trozo = ""
        chunk = ""


        # start infinite loop
        # start an infinite loop
        while True:
        while True:
            # read last line of file
            # read last line of file
            line = thefile.readline()
            line = file.readline()


            # sleep if file hasn't been updated
            # sleep if file hasn't been updated
            if not line:
            if not line:
                time.sleep(time_sleep)
                time.sleep(time_sleep)
                continue
                continue

            if line[-1] != "\n":
            if line[-1] != "\n":
                trozo += line
                chunk += line
            else:
            else:
                if trozo != "":
                if chunk != "":
                    line = trozo + line
                    line = chunk + line
                    trozo = ""
                    chunk = ""
                yield line


                yield line


    def load_file(self, dirname=TSTAT_DIR_NAME):  # - Client side -
    def load_file(self, dirname=TSTAT_DIR_NAME):  # - Client side -
        while True:
        while True:
@@ -141,7 +150,6 @@ class l3_distributedattackdetector():
                LOGGER.info("No Tstat directory!")
                LOGGER.info("No Tstat directory!")
                time.sleep(5)
                time.sleep(5)



    def process_line(self, line):
    def process_line(self, line):
        """
        """
        - Preprocessing before a message per line
        - Preprocessing before a message per line
@@ -161,7 +169,6 @@ class l3_distributedattackdetector():


        return values
        return values



    def get_service_ids(self, context_id_str):
    def get_service_ids(self, context_id_str):
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
            stub = ContextServiceStub(channel)
            stub = ContextServiceStub(channel)
@@ -169,7 +176,6 @@ class l3_distributedattackdetector():
            context_id.context_uuid.uuid = context_id_str
            context_id.context_uuid.uuid = context_id_str
            return stub.ListServiceIds(context_id)
            return stub.ListServiceIds(context_id)



    def get_services(self, context_id_str):
    def get_services(self, context_id_str):
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
            stub = ContextServiceStub(channel)
            stub = ContextServiceStub(channel)
@@ -177,7 +183,6 @@ class l3_distributedattackdetector():
            context_id.context_uuid.uuid = context_id_str
            context_id.context_uuid.uuid = context_id_str
            return stub.ListServices(context_id)
            return stub.ListServices(context_id)



    def get_service_id(self, context_id):
    def get_service_id(self, context_id):
        service_id_list = self.get_service_ids(context_id)
        service_id_list = self.get_service_ids(context_id)
        service_id = None
        service_id = None
@@ -190,7 +195,6 @@ class l3_distributedattackdetector():


        return service_id
        return service_id



    def get_service_id2(self, context_id):
    def get_service_id2(self, context_id):
        service_list = self.get_services(context_id)
        service_list = self.get_services(context_id)
        service_id = None
        service_id = None
@@ -202,7 +206,6 @@ class l3_distributedattackdetector():
                pass
                pass
        return service_id
        return service_id



    def get_endpoint_id(self, context_id):
    def get_endpoint_id(self, context_id):
        service_list = self.get_services(context_id)
        service_list = self.get_services(context_id)
        endpoint_id = None
        endpoint_id = None
@@ -212,11 +215,9 @@ class l3_distributedattackdetector():
                break
                break
        return endpoint_id
        return endpoint_id



    def get_features_ids(self):
    def get_features_ids(self):
        return self.cad.GetFeaturesIds(Empty()).auto_features
        return self.cad.GetFeaturesIds(Empty()).auto_features



    def check_types(self):
    def check_types(self):
        for feature in self.cad_features["features"]:
        for feature in self.cad_features["features"]:
            assert isinstance(feature, float)
            assert isinstance(feature, float)
@@ -230,7 +231,6 @@ class l3_distributedattackdetector():
        assert isinstance(self.cad_features["connection_metadata"]["time_start"], float)
        assert isinstance(self.cad_features["connection_metadata"]["time_start"], float)
        assert isinstance(self.cad_features["connection_metadata"]["time_end"], float)
        assert isinstance(self.cad_features["connection_metadata"]["time_end"], float)



    def insert_connection(self):
    def insert_connection(self):
        try:
        try:
            self.connections_dict[self.conn_id]["time_end"] = time.time()
            self.connections_dict[self.conn_id]["time_end"] = time.time()
@@ -247,12 +247,10 @@ class l3_distributedattackdetector():
            self.connections_dict[self.conn_id]["ip_d"] = self.conn_id[2]
            self.connections_dict[self.conn_id]["ip_d"] = self.conn_id[2]
            self.connections_dict[self.conn_id]["port_d"] = self.conn_id[3]
            self.connections_dict[self.conn_id]["port_d"] = self.conn_id[3]



    def check_if_connection_is_attack(self):
    def check_if_connection_is_attack(self):
        if self.conn_id[0] in ATTACK_IPS or self.conn_id[2] in ATTACK_IPS:
        if self.conn_id[0] in self.known_attack_ips or self.conn_id[2] in self.known_attack_ips:
            LOGGER.info("Attack detected. Origin: {0}, destination: {1}".format(self.conn_id[0], self.conn_id[2]))
            LOGGER.info("Attack detected. Origin: {0}, destination: {1}".format(self.conn_id[0], self.conn_id[2]))



    def create_cad_features(self):
    def create_cad_features(self):
        self.cad_features = {
        self.cad_features = {
            "features": self.new_connections[self.conn_id][0:10],
            "features": self.new_connections[self.conn_id][0:10],
@@ -270,7 +268,6 @@ class l3_distributedattackdetector():
            },
            },
        }
        }



    async def send_batch_async(self, metrics_list_pb):
    async def send_batch_async(self, metrics_list_pb):
        loop = asyncio.get_running_loop()
        loop = asyncio.get_running_loop()


@@ -279,16 +276,13 @@ class l3_distributedattackdetector():
        metrics_batch.metrics.extend(metrics_list_pb)
        metrics_batch.metrics.extend(metrics_list_pb)


        # Send batch
        # Send batch
        future = loop.run_in_executor(
        future = loop.run_in_executor(None, self.cad.AnalyzeBatchConnectionStatistics, metrics_batch)
            None, self.cad.AnalyzeBatchConnectionStatistics, metrics_batch
        )


        try:
        try:
            await future
            await future
        except Exception as e:
        except Exception as e:
            LOGGER.error(f"Error sending batch: {e}")
            LOGGER.error(f"Error sending batch: {e}")



    async def send_data(self, metrics_list_pb, send_data_times):
    async def send_data(self, metrics_list_pb, send_data_times):
        # Send to CAD
        # Send to CAD
        if SEND_DATA_IN_BATCHES:
        if SEND_DATA_IN_BATCHES:
@@ -311,7 +305,6 @@ class l3_distributedattackdetector():


        return metrics_list_pb, send_data_times
        return metrics_list_pb, send_data_times



    async def process_traffic(self):
    async def process_traffic(self):
        LOGGER.info("Loading Tstat log file...")
        LOGGER.info("Loading Tstat log file...")
        logfile = open(self.load_file(), "r")
        logfile = open(self.load_file(), "r")