From 3b7838c5aa510b80fa2a65551ae44278998754f7 Mon Sep 17 00:00:00 2001 From: ldemarcosm <l.demarcosm@alumnos.upm.es> Date: Tue, 2 Nov 2021 11:55:41 +0100 Subject: [PATCH] Fixing minor issues 2 --- src/l3_attackmitigator/service/__main__.py | 4 ++-- .../service/l3_attackmitigatorService.py | 4 +--- .../l3_attackmitigatorServiceServicerImpl.py | 4 +--- src/l3_attackmitigator/tests/test_unitary.py | 14 ++++---------- .../service/__main__.py | 4 +--- .../service/l3_centralizedattackdetectorService.py | 4 +--- ...centralizedattackdetectorServiceServicerImpl.py | 8 +++----- ...distributedattackdetectorServiceServicerImpl.py | 5 +++-- 8 files changed, 16 insertions(+), 31 deletions(-) diff --git a/src/l3_attackmitigator/service/__main__.py b/src/l3_attackmitigator/service/__main__.py index 394d3c3f9..3843c6c18 100644 --- a/src/l3_attackmitigator/service/__main__.py +++ b/src/l3_attackmitigator/service/__main__.py @@ -34,11 +34,11 @@ def main(): start_http_server(metrics_port) # Get database instance - database = get_database() + #database = get_database() # Starting l3_attackmitigator service grpc_service = l3_attackmitigatorService( - database, port=service_port, max_workers=max_workers, grace_period=grace_period) + port=service_port, max_workers=max_workers, grace_period=grace_period) grpc_service.start() # Wait for Ctrl+C or termination signal diff --git a/src/l3_attackmitigator/service/l3_attackmitigatorService.py b/src/l3_attackmitigator/service/l3_attackmitigatorService.py index bfad98200..254457c70 100644 --- a/src/l3_attackmitigator/service/l3_attackmitigatorService.py +++ b/src/l3_attackmitigator/service/l3_attackmitigatorService.py @@ -23,13 +23,11 @@ LOGGER = logging.getLogger(__name__) class l3_attackmitigatorService: def __init__( self, - database, address=BIND_ADDRESS, port=GRPC_SERVICE_PORT, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD, ): - self.database = database self.address = address self.port = port self.endpoint = None @@ -52,7 +50,7 @@ class l3_attackmitigatorService: self.server = grpc.server(self.pool) # , interceptors=(tracer_interceptor,)) self.l3_attackmitigator_servicer = ( - l3_attackmitigatorServiceServicerImpl(self.database) + l3_attackmitigatorServiceServicerImpl() ) add_L3AttackmitigatorServicer_to_server( self.l3_attackmitigator_servicer, self.server diff --git a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py index 0e2a64c1e..660d790bf 100644 --- a/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py +++ b/src/l3_attackmitigator/service/l3_attackmitigatorServiceServicerImpl.py @@ -27,10 +27,8 @@ LAST_VALUE = -1 class l3_attackmitigatorServiceServicerImpl(L3AttackmitigatorServicer): - def __init__(self, database: Database): + def __init__(self): LOGGER.debug("Creating Servicer...") - self.database = database - LOGGER.debug("Servicer Created") class Mitigator(L3AttackmitigatorServicer): def send_output(self, request, context): diff --git a/src/l3_attackmitigator/tests/test_unitary.py b/src/l3_attackmitigator/tests/test_unitary.py index 4acecd94e..54579472f 100644 --- a/src/l3_attackmitigator/tests/test_unitary.py +++ b/src/l3_attackmitigator/tests/test_unitary.py @@ -25,15 +25,9 @@ LOGGER.setLevel(logging.DEBUG) @pytest.fixture(scope='session') -def database(): - _database = get_database(engine=DatabaseEngineEnum.INMEMORY) - return _database - - -@pytest.fixture(scope='session') -def l3_attackmitigator_service(database): +def l3_attackmitigator_service(): _service = l3_attackmitigatorService( - database, port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD) + port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD) _service.start() yield _service _service.stop() @@ -49,10 +43,10 @@ def test_demo(): print('Demo Test') pass -def test_grpc_server(database): +def test_grpc_server(): print('Starting AM') _service = l3_attackmitigatorService( - database, port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD) + port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD) p1 = multiprocessing.Process(target=_service.start, args=()) #_service.start() p1.start() diff --git a/src/l3_centralizedattackdetector/service/__main__.py b/src/l3_centralizedattackdetector/service/__main__.py index eac7b5f22..1e593111b 100644 --- a/src/l3_centralizedattackdetector/service/__main__.py +++ b/src/l3_centralizedattackdetector/service/__main__.py @@ -33,12 +33,10 @@ def main(): # Start metrics server start_http_server(metrics_port) - # Get database instance - database = get_database() # Starting l3_centralizedattackdetector service grpc_service = l3_centralizedattackdetectorService( - database, port=service_port, max_workers=max_workers, grace_period=grace_period) + port=service_port, max_workers=max_workers, grace_period=grace_period) grpc_service.start() # Wait for Ctrl+C or termination signal diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorService.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorService.py index 7c71fc77f..1bcb50e1e 100644 --- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorService.py +++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorService.py @@ -23,13 +23,11 @@ LOGGER = logging.getLogger(__name__) class l3_centralizedattackdetectorService: def __init__( self, - database, address=BIND_ADDRESS, port=GRPC_SERVICE_PORT, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD, ): - self.database = database self.address = address self.port = port self.endpoint = None @@ -52,7 +50,7 @@ class l3_centralizedattackdetectorService: self.server = grpc.server(self.pool) # , interceptors=(tracer_interceptor,)) self.l3_centralizedattackdetector_servicer = ( - l3_centralizedattackdetectorServiceServicerImpl(self.database) + l3_centralizedattackdetectorServiceServicerImpl() ) add_L3CentralizedattackdetectorServicer_to_server( self.l3_centralizedattackdetector_servicer, self.server diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py index 0e137f424..4ff4f58ee 100644 --- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py @@ -25,7 +25,7 @@ from l3_centralizedattackdetector.proto.l3_attackmitigator_pb2 import ( Output, ) from l3_centralizedattackdetector.proto.l3_attackmitigator_pb2_grpc import ( - l3_attackmitigatorStub, + L3AttackmitigatorStub, ) LOGGER = logging.getLogger(__name__) @@ -36,10 +36,8 @@ INFERENCE_VALUES = [] class l3_centralizedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer): - def __init__(self, database: Database): + def __init__(self): LOGGER.debug("Creating Servicer...") - self.database = database - LOGGER.debug("Servicer Created") class CAD(L3CentralizedattackdetectorServicer): def __init__(self, ml_model): @@ -61,7 +59,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto with grpc.insecure_channel("localhost:10002") as channel: stub = l3_attackmitigatorStub(channel) print("Sending to mitigator...") - response = stub.send_output(output) + response = stub.SendOutput(output) print("Sent output to mitigator and received: ", response.message) # RETURN "OK" TO THE CALLER diff --git a/src/l3_distributedattackdetector/service/l3_distributedattackdetectorServiceServicerImpl.py b/src/l3_distributedattackdetector/service/l3_distributedattackdetectorServiceServicerImpl.py index e90fe3128..0f8cc3d31 100644 --- a/src/l3_distributedattackdetector/service/l3_distributedattackdetectorServiceServicerImpl.py +++ b/src/l3_distributedattackdetector/service/l3_distributedattackdetectorServiceServicerImpl.py @@ -11,12 +11,13 @@ from l3_distributedattackdetector.proto.l3_centralizedattackdetector_pb2 import ) from l3_distributedattackdetector.proto.l3_centralizedattackdetector_pb2_grpc import ( L3CentralizedattackdetectorStub, + L3CentralizedattackdetectorServicer ) LOGGER = logging.getLogger(__name__) TSTAT_DIR_NAME = "piped/" -class l3_distributedattackdetectorServiceServicerImpl(): +class l3_distributedattackdetectorServiceServicerImpl(L3CentralizedattackdetectorServicer): def __init__(self): LOGGER.debug("Creating Servicer...") @@ -116,7 +117,7 @@ class l3_distributedattackdetectorServiceServicerImpl(): def open_channel(self, input_information): with grpc.insecure_channel("localhost:10001") as channel: stub = L3CentralizedattackdetectorStub(channel) - response = stub.send_input( + response = stub.SendInput( ModelInput(**input_information)) logging.debug("Inferencer send_input sent and received: ", response.message) -- GitLab