Commit b81596f5 authored by Amit Karamchandani Batra's avatar Amit Karamchandani Batra
Browse files

Cleanup of the DAD service implementation

parent b3c0e922
Loading
Loading
Loading
Loading
+10 −4
Original line number Original line Diff line number Diff line
@@ -177,6 +177,7 @@ class l3_distributedattackdetector:
            stub = ContextServiceStub(channel)
            stub = ContextServiceStub(channel)
            context_id = ContextId()
            context_id = ContextId()
            context_id.context_uuid.uuid = context_id_str
            context_id.context_uuid.uuid = context_id_str
            
            return stub.ListServiceIds(context_id)
            return stub.ListServiceIds(context_id)


    def get_services(self, context_id_str):
    def get_services(self, context_id_str):
@@ -184,26 +185,31 @@ class l3_distributedattackdetector:
            stub = ContextServiceStub(channel)
            stub = ContextServiceStub(channel)
            context_id = ContextId()
            context_id = ContextId()
            context_id.context_uuid.uuid = context_id_str
            context_id.context_uuid.uuid = context_id_str
            
            return stub.ListServices(context_id)
            return stub.ListServices(context_id)


    def get_service_id(self, context_id):
    def get_service_id(self, context_id):
        service_list = self.get_services(context_id)
        service_list = self.get_services(context_id)
        service_id = None
        service_id = None
        
        for s in service_list.services:
        for s in service_list.services:
            if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM:
            if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM:
                service_id = s.service_id
                service_id = s.service_id
                break
                break
            else:
            else:
                pass
                pass
            
        return service_id
        return service_id


    def get_endpoint_id(self, context_id):
    def get_endpoint_id(self, context_id):
        service_list = self.get_services(context_id)
        service_list = self.get_services(context_id)
        endpoint_id = None
        endpoint_id = None
        
        for s in service_list.services:
        for s in service_list.services:
            if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM:
            if s.service_type == ServiceTypeEnum.SERVICETYPE_L3NM:
                endpoint_id = s.service_endpoint_ids[0]
                endpoint_id = s.service_endpoint_ids[0]
                break
                break
            
        return endpoint_id
        return endpoint_id


    def get_features_ids(self):
    def get_features_ids(self):
@@ -275,7 +281,7 @@ class l3_distributedattackdetector:
            LOGGER.error(f"Error sending batch: {e}")
            LOGGER.error(f"Error sending batch: {e}")


    async def send_data(self, metrics_list_pb, send_data_times):
    async def send_data(self, metrics_list_pb, send_data_times):
        # Send to CAD
        # Send data to CAD
        if SEND_DATA_IN_BATCHES:
        if SEND_DATA_IN_BATCHES:
            if len(metrics_list_pb) == BATCH_SIZE:
            if len(metrics_list_pb) == BATCH_SIZE:
                send_data_time_start = time.time()
                send_data_time_start = time.time()
@@ -319,24 +325,24 @@ class l3_distributedattackdetector:
                LOGGER.info("Waiting for new data...")
                LOGGER.info("Waiting for new data...")
                time.sleep(1 / 100)
                time.sleep(1 / 100)
                line = next(loglines, None)
                line = next(loglines, None)
                
            if index == 0 and IGNORE_FIRST_LINE_TSTAT:
            if index == 0 and IGNORE_FIRST_LINE_TSTAT:
                index = index + 1
                index = index + 1
                continue
                continue
            
            if STOP:
            if STOP:
                break
                break


            num_lines += 1
            num_lines += 1
            start = time.time()
            start = time.time()
            line_id = line.split(" ")
            line_id = line.split(" ")
            
            self.conn_id = (line_id[0], line_id[1], line_id[14], line_id[15])
            self.conn_id = (line_id[0], line_id[1], line_id[14], line_id[15])
            self.new_connections[self.conn_id] = self.process_line(line)
            self.new_connections[self.conn_id] = self.process_line(line)


            self.check_if_connection_is_attack()
            self.check_if_connection_is_attack()

            self.insert_connection()
            self.insert_connection()

            self.create_cad_features()
            self.create_cad_features()

            self.check_types()
            self.check_types()


            connection_metadata = ConnectionMetadata(**self.cad_features["connection_metadata"])
            connection_metadata = ConnectionMetadata(**self.cad_features["connection_metadata"])