Commit ce468995 authored by Amit Karamchandani Batra's avatar Amit Karamchandani Batra
Browse files

Cleanup of the DAD service implementation

parent 3f081d8d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
37.187.95.110,91.121.140.167,94.23.23.52,94.23.247.226,149.202.83.171
 No newline at end of file
+53 −60
Original line number Diff line number Diff line
@@ -13,24 +13,22 @@
# limitations under the License.

import asyncio
import grpc
import logging
import numpy as np
import os
import signal
import time
from sys import stdout
from common.proto.context_pb2 import (
    Empty,
    ServiceTypeEnum,
    ContextId,
)

import grpc
import numpy as np

from common.proto.context_pb2 import ContextId, Empty, ServiceTypeEnum
from common.proto.context_pb2_grpc import ContextServiceStub
from common.proto.l3_centralizedattackdetector_pb2 import (
    L3CentralizedattackdetectorMetrics,
    L3CentralizedattackdetectorBatchInput,
    ConnectionMetadata,
    Feature,
    L3CentralizedattackdetectorBatchInput,
    L3CentralizedattackdetectorMetrics,
)
from common.proto.l3_centralizedattackdetector_pb2_grpc import L3CentralizedattackdetectorStub

@@ -42,8 +40,12 @@ consoleHandler = logging.StreamHandler(stdout)
consoleHandler.setFormatter(logFormatter)
LOGGER.addHandler(consoleHandler)

# Define constants
TSTAT_DIR_NAME = "piped/"
CENTRALIZED_ATTACK_DETECTOR = "192.168.165.78:10001"
CONTROLLER_IP = "192.168.165.78"  # Change this to the IP of the controller
CONTEXT_ID = "admin"  # Change this to the context ID to be used
CONTEXT_CHANNEL = f"{CONTROLLER_IP}:1010"  # Change this to the IP of the controller
CENTRALIZED_ATTACK_DETECTOR = f"{CONTROLLER_IP}:10001"
JSON_BLANK = {
    "ip_o": "",  # Client IP
    "port_o": "",  # Client port
@@ -54,18 +56,14 @@ JSON_BLANK = {
    "time_start": 0.0,  # Start of connection
    "time_end": 0.0,  # Time of last packet
}

STOP = False
IGNORE_FIRST_LINE_TSTAT = True

CONTEXT_ID = "admin"
CONTEXT_CHANNEL = "192.168.165.78:1010"
PROFILING = False
SEND_DATA_IN_BATCHES = False
BATCH_SIZE = 10
ATTACK_IPS = ["37.187.95.110", "91.121.140.167", "94.23.23.52", "94.23.247.226", "149.202.83.171"]

class l3_distributedattackdetector():

class l3_distributedattackdetector:
    def __init__(self):
        LOGGER.info("Creating Distributed Attack Detector")
        
@@ -77,6 +75,8 @@ class l3_distributedattackdetector():
        self.connections_dict = {}  # Dict for storing ALL data
        self.new_connections = {}  # Dict for storing NEW data

        self.known_attack_ips = self.read_kwnown_attack_ips()
        
        signal.signal(signal.SIGINT, self.handler)

        with grpc.insecure_channel(CENTRALIZED_ATTACK_DETECTOR) as channel:
@@ -89,6 +89,14 @@ class l3_distributedattackdetector():

            asyncio.run(self.process_traffic())
    
    def read_kwnown_attack_ips(self):
        """Load the list of known attacker IP addresses from ``known_attack_ips.csv``.

        The file may list IPs one per line and/or comma-separated on a single
        line (the shipped file uses a single comma-separated line, which the
        previous ``splitlines()``-only parsing returned as one big string and
        therefore never matched individual IPs).

        NOTE(review): the method name keeps the original 'kwnown' typo so the
        existing caller in __init__ keeps working.

        Returns:
            list[str]: the known attacker IPs, whitespace-stripped, empty
            entries removed.
        """
        with open("known_attack_ips.csv", "r") as f:
            content = f.read()

        # Split on newlines first, then on commas, so both file layouts work.
        known_attack_ips = [
            ip.strip()
            for line in content.splitlines()
            for ip in line.split(",")
            if ip.strip()
        ]

        return known_attack_ips

    def handler(self):
        if STOP:
@@ -98,33 +106,34 @@ class l3_distributedattackdetector():

        LOGGER.info("Gracefully stopping...")

    def follow(self, thefile, time_sleep):
    def follow(self, file, time_sleep):
        """
        Generator function that yields new lines in a file
        It reads the logfile (the opened file)
        """

        # seek the end of the file
        # thefile.seek(0, os.SEEK_END)
        # file.seek(0, os.SEEK_END)

        trozo = ""
        chunk = ""

        # start infinite loop
        # start an infinite loop
        while True:
            # read last line of file
            line = thefile.readline()
            line = file.readline()

            # sleep if file hasn't been updated
            if not line:
                time.sleep(time_sleep)
                continue

            if line[-1] != "\n":
                trozo += line
                chunk += line
            else:
                if trozo != "":
                    line = trozo + line
                    trozo = ""
                yield line
                if chunk != "":
                    line = chunk + line
                    chunk = ""

                yield line

    def load_file(self, dirname=TSTAT_DIR_NAME):  # - Client side -
        while True:
@@ -141,7 +150,6 @@ class l3_distributedattackdetector():
                LOGGER.info("No Tstat directory!")
                time.sleep(5)


    def process_line(self, line):
        """
        - Preprocessing before a message per line
@@ -161,7 +169,6 @@ class l3_distributedattackdetector():

        return values


    def get_service_ids(self, context_id_str):
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
            stub = ContextServiceStub(channel)
@@ -169,7 +176,6 @@ class l3_distributedattackdetector():
            context_id.context_uuid.uuid = context_id_str
            return stub.ListServiceIds(context_id)


    def get_services(self, context_id_str):
        with grpc.insecure_channel(CONTEXT_CHANNEL) as channel:
            stub = ContextServiceStub(channel)
@@ -177,7 +183,6 @@ class l3_distributedattackdetector():
            context_id.context_uuid.uuid = context_id_str
            return stub.ListServices(context_id)


    def get_service_id(self, context_id):
        service_id_list = self.get_service_ids(context_id)
        service_id = None
@@ -190,7 +195,6 @@ class l3_distributedattackdetector():

        return service_id


    def get_service_id2(self, context_id):
        service_list = self.get_services(context_id)
        service_id = None
@@ -202,7 +206,6 @@ class l3_distributedattackdetector():
                pass
        return service_id


    def get_endpoint_id(self, context_id):
        service_list = self.get_services(context_id)
        endpoint_id = None
@@ -212,11 +215,9 @@ class l3_distributedattackdetector():
                break
        return endpoint_id


    def get_features_ids(self):
        """Ask the centralized attack detector (CAD) for its auto-extracted feature ids."""
        response = self.cad.GetFeaturesIds(Empty())
        return response.auto_features


    def check_types(self):
        for feature in self.cad_features["features"]:
            assert isinstance(feature, float)
@@ -230,7 +231,6 @@ class l3_distributedattackdetector():
        assert isinstance(self.cad_features["connection_metadata"]["time_start"], float)
        assert isinstance(self.cad_features["connection_metadata"]["time_end"], float)


    def insert_connection(self):
        try:
            self.connections_dict[self.conn_id]["time_end"] = time.time()
@@ -247,12 +247,10 @@ class l3_distributedattackdetector():
            self.connections_dict[self.conn_id]["ip_d"] = self.conn_id[2]
            self.connections_dict[self.conn_id]["port_d"] = self.conn_id[3]


    def check_if_connection_is_attack(self):
        if self.conn_id[0] in ATTACK_IPS or self.conn_id[2] in ATTACK_IPS:
        if self.conn_id[0] in self.known_attack_ips or self.conn_id[2] in self.known_attack_ips:
            LOGGER.info("Attack detected. Origin: {0}, destination: {1}".format(self.conn_id[0], self.conn_id[2]))


    def create_cad_features(self):
        self.cad_features = {
            "features": self.new_connections[self.conn_id][0:10],
@@ -270,7 +268,6 @@ class l3_distributedattackdetector():
            },
        }


    async def send_batch_async(self, metrics_list_pb):
        loop = asyncio.get_running_loop()

@@ -279,16 +276,13 @@ class l3_distributedattackdetector():
        metrics_batch.metrics.extend(metrics_list_pb)

        # Send batch
        future = loop.run_in_executor(
            None, self.cad.AnalyzeBatchConnectionStatistics, metrics_batch
        )
        future = loop.run_in_executor(None, self.cad.AnalyzeBatchConnectionStatistics, metrics_batch)

        try:
            await future
        except Exception as e:
            LOGGER.error(f"Error sending batch: {e}")


    async def send_data(self, metrics_list_pb, send_data_times):
        # Send to CAD
        if SEND_DATA_IN_BATCHES:
@@ -311,7 +305,6 @@ class l3_distributedattackdetector():

        return metrics_list_pb, send_data_times


    async def process_traffic(self):
        LOGGER.info("Loading Tstat log file...")
        logfile = open(self.load_file(), "r")