diff --git a/.gitignore b/.gitignore index b0d99f8bb5fd4e09019732c5c99bbc5163bc617e..53f8e67287f1e43cf05af70f8252eb4df3576be2 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ __pycache__/ # C extensions *.so +# workspace configurations +pyproject.toml + # Distribution / packaging .Python build/ diff --git a/deploy.sh b/deploy.sh index 3ae4bbcaa0e7d665ff7f60e839f90e38a936feaf..bcf162c2e3aa3baf8a7b3c6ac9fbc69a46ddd962 100755 --- a/deploy.sh +++ b/deploy.sh @@ -58,31 +58,12 @@ kubectl delete namespace $TFS_K8S_NAMESPACE kubectl create namespace $TFS_K8S_NAMESPACE printf "\n" -# if [[ "$TFS_COMPONENTS" == *"monitoring"* ]]; then -# echo "Creating secrets for InfluxDB..." -# #TODO: make sure to change this when having a production deployment -# kubectl create secret generic influxdb-secrets --namespace=$TFS_K8S_NAMESPACE \ -# --from-literal=INFLUXDB_DB="monitoring" --from-literal=INFLUXDB_ADMIN_USER="teraflow" \ -# --from-literal=INFLUXDB_ADMIN_PASSWORD="teraflow" --from-literal=INFLUXDB_HTTP_AUTH_ENABLED="True" -# kubectl create secret generic monitoring-secrets --namespace=$TFS_K8S_NAMESPACE \ -# --from-literal=INFLUXDB_DATABASE="monitoring" --from-literal=INFLUXDB_USER="teraflow" \ -# --from-literal=INFLUXDB_PASSWORD="teraflow" --from-literal=INFLUXDB_HOSTNAME="localhost" -# printf "\n" -# fi - echo "Deploying components and collecting environment variables..." ENV_VARS_SCRIPT=tfs_runtime_env_vars.sh echo "# Environment variables for TeraFlowSDN deployment" > $ENV_VARS_SCRIPT PYTHONPATH=$(pwd)/src echo "export PYTHONPATH=${PYTHONPATH}" >> $ENV_VARS_SCRIPT -# more info: https://www.containiq.com/post/deploy-redis-cluster-on-kubernetes -# generating password for Redis -REDIS_PASSWORD=`uuidgen` -kubectl create secret generic redis-secrets --namespace=$TFS_K8S_NAMESPACE \ - --from-literal=REDIS_PASSWORD=$REDIS_PASSWORD -echo "export REDIS_PASSWORD=${REDIS_PASSWORD}" >> $ENV_VARS_SCRIPT - for COMPONENT in $TFS_COMPONENTS; do echo "Processing '$COMPONENT' component..." IMAGE_NAME="$COMPONENT:$TFS_IMAGE_TAG" diff --git a/manifests/opticalattackdetectorservice.yaml b/manifests/opticalattackdetectorservice.yaml index ee806865552bf03470b6e0daaae1fccd9be1ad0f..2d5d3bf6cca1f1bbee0e14ef288a775563916780 100644 --- a/manifests/opticalattackdetectorservice.yaml +++ b/manifests/opticalattackdetectorservice.yaml @@ -35,11 +35,6 @@ spec: env: - name: LOG_LEVEL value: "DEBUG" - - name: REDIS_PASSWORD - valueFrom: - secretKeyRef: - name: redis-secrets - key: REDIS_PASSWORD readinessProbe: exec: command: ["/bin/grpc_health_probe", "-addr=:10006"] diff --git a/manifests/opticalattackmanagerservice.yaml b/manifests/opticalattackmanagerservice.yaml index a9e60ed8ac3190a441b681af0ce42bf200610941..21309ab8f91e7a10dc22c26c8cfc83fa1a9f4b6b 100644 --- a/manifests/opticalattackmanagerservice.yaml +++ b/manifests/opticalattackmanagerservice.yaml @@ -35,11 +35,6 @@ spec: env: - name: LOG_LEVEL value: "DEBUG" - - name: REDIS_PASSWORD - valueFrom: - secretKeyRef: - name: redis-secrets - key: REDIS_PASSWORD resources: requests: cpu: 250m diff --git a/my_deploy.sh b/my_deploy.sh index 6d2ca46c5cfa983bb3ca4688ae85af515c65473b..2153da8da78082a122fbd62cc85b4d93033da8da 100644 --- a/my_deploy.sh +++ b/my_deploy.sh @@ -7,7 +7,7 @@ export TFS_REGISTRY_IMAGE="http://localhost:32000/tfs/" # interdomain slice pathcomp dlt # dbscanserving opticalattackmitigator opticalcentralizedattackdetector # l3_attackmitigator l3_centralizedattackdetector l3_distributedattackdetector -export TFS_COMPONENTS="context device automation service compute monitoring webui dbscanserving opticalattackmitigator" # opticalattackmanager opticalattackdetector +export TFS_COMPONENTS="context device automation service compute monitoring webui dbscanserving opticalattackmitigator opticalattackmanager opticalattackdetector" # Set the tag you want to use for your images. export TFS_IMAGE_TAG="dev" @@ -16,7 +16,7 @@ export TFS_IMAGE_TAG="dev" export TFS_K8S_NAMESPACE="tfs" # Set additional manifest files to be applied after the deployment -export TFS_EXTRA_MANIFESTS="manifests/nginx_ingress_http.yaml manifests/cachingservice.yaml" +export TFS_EXTRA_MANIFESTS="manifests/nginx_ingress_http.yaml" # Set the neew Grafana admin password export TFS_GRAFANA_PASSWORD="admin123+" diff --git a/proto/generate_code_python.sh b/proto/generate_code_python.sh index b0df357eb079fb2721cffca43465588f7013e341..f28dbe4fde13c56f20a454049ab220a21f63a663 100755 --- a/proto/generate_code_python.sh +++ b/proto/generate_code_python.sh @@ -38,5 +38,8 @@ EOF # Generate Python code python3 -m grpc_tools.protoc -I=./ --python_out=src/python/ --grpc_python_out=src/python/ *.proto +# new line added to generate protobuf for the `grpclib` library +python3 -m grpc_tools.protoc -I=./ --python_out=src/python/asyncio --grpclib_python_out=src/python/asyncio *.proto + # Arrange generated code imports to enable imports from arbitrary subpackages find src/python -type f -iname *.py -exec sed -i -E 's/(import\ .*)_pb2/from . \1_pb2/g' {} \; diff --git a/src/common/Constants.py b/src/common/Constants.py index f18d4384035f2310355d7a16c5a709720b5b07e9..97fbc7b72d4459438502f6052d5b64b90e8dc198 100644 --- a/src/common/Constants.py +++ b/src/common/Constants.py @@ -30,39 +30,45 @@ DEFAULT_HTTP_BIND_ADDRESS = '0.0.0.0' DEFAULT_METRICS_PORT = 9192 # Default context and topology UUIDs -DEFAULT_CONTEXT_UUID = 'admin' +DEFAULT_CONTEXT_UUID = 'admin' DEFAULT_TOPOLOGY_UUID = 'admin' # Default service names class ServiceNameEnum(Enum): - CONTEXT = 'context' - DEVICE = 'device' - SERVICE = 'service' - SLICE = 'slice' - AUTOMATION = 'automation' - POLICY = 'policy' - MONITORING = 'monitoring' - DLT = 'dlt' - COMPUTE = 'compute' - CYBERSECURITY = 'cybersecurity' - INTERDOMAIN = 'interdomain' - PATHCOMP = 'pathcomp' - WEBUI = 'webui' + CONTEXT = 'context' + DEVICE = 'device' + SERVICE = 'service' + SLICE = 'slice' + AUTOMATION = 'automation' + POLICY = 'policy' + MONITORING = 'monitoring' + DLT = 'dlt' + COMPUTE = 'compute' + DBSCANSERVING = 'dbscanserving' + OPTICALATTACKMANAGER = 'opticalattackmanager' + OPTICALATTACKDETECTOR = 'opticalattackdetector' + OPTICALATTACKMITIGATOR = 'opticalattackmitigator' + INTERDOMAIN = 'interdomain' + PATHCOMP = 'pathcomp' + WEBUI = 'webui' # Default gRPC service ports DEFAULT_SERVICE_GRPC_PORTS = { - ServiceNameEnum.CONTEXT .value : 1010, - ServiceNameEnum.DEVICE .value : 2020, - ServiceNameEnum.SERVICE .value : 3030, - ServiceNameEnum.SLICE .value : 4040, - ServiceNameEnum.AUTOMATION .value : 5050, - ServiceNameEnum.POLICY .value : 6060, - ServiceNameEnum.MONITORING .value : 7070, - ServiceNameEnum.DLT .value : 8080, - ServiceNameEnum.COMPUTE .value : 9090, - ServiceNameEnum.CYBERSECURITY.value : 10000, - ServiceNameEnum.INTERDOMAIN .value : 10010, - ServiceNameEnum.PATHCOMP .value : 10020, + ServiceNameEnum.CONTEXT .value : 1010, + ServiceNameEnum.DEVICE .value : 2020, + ServiceNameEnum.SERVICE .value : 3030, + ServiceNameEnum.SLICE .value : 4040, + ServiceNameEnum.AUTOMATION .value : 5050, + ServiceNameEnum.POLICY .value : 6060, + ServiceNameEnum.MONITORING .value : 7070, + ServiceNameEnum.DLT .value : 8080, + ServiceNameEnum.COMPUTE .value : 9090, + ServiceNameEnum.OPTICALATTACKMANAGER .value : 10005, + ServiceNameEnum.OPTICALATTACKDETECTOR .value : 10006, + ServiceNameEnum.OPTICALATTACKMITIGATOR .value : 10007, + ServiceNameEnum.DBSCANSERVING .value : 10008, + ServiceNameEnum.INTERDOMAIN .value : 10010, + ServiceNameEnum.PATHCOMP .value : 10020, } # Default HTTP/REST-API service ports diff --git a/src/dbscanserving/Config.py b/src/dbscanserving/Config.py index d3140b29373b0110c8571440db6816534131c482..372416fad00d11e9bc01da3e697ddb6e3935ead5 100644 --- a/src/dbscanserving/Config.py +++ b/src/dbscanserving/Config.py @@ -19,7 +19,7 @@ LOG_LEVEL = logging.DEBUG # gRPC settings GRPC_SERVICE_PORT = 10008 -GRPC_MAX_WORKERS = 10 +GRPC_MAX_WORKERS = 10 GRPC_GRACE_PERIOD = 60 # Prometheus settings diff --git a/src/dbscanserving/__init__.py b/src/dbscanserving/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/dbscanserving/__init__.py +++ b/src/dbscanserving/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/dbscanserving/client/DbscanServingClient.py b/src/dbscanserving/client/DbscanServingClient.py index 3116e9fe3b65d4c5ba37b3ebbfc9f4819b97ace3..5069cab0b02bf8e27f21f32368db507da6af5908 100644 --- a/src/dbscanserving/client/DbscanServingClient.py +++ b/src/dbscanserving/client/DbscanServingClient.py @@ -12,46 +12,61 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from email.policy import default -import grpc, logging -from common.Settings import get_log_level, get_setting -from common.tools.client.RetryDecorator import retry, delay_exponential + +import grpc from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse from common.proto.dbscanserving_pb2_grpc import DetectorStub +from common.Settings import get_log_level, get_setting +from common.tools.client.RetryDecorator import delay_exponential, retry log_level = get_log_level() logging.basicConfig(level=log_level) LOGGER = logging.getLogger(__name__) MAX_RETRIES = 15 DELAY_FUNCTION = delay_exponential(initial=0.01, increment=2.0, maximum=5.0) -RETRY_DECORATOR = retry(max_retries=MAX_RETRIES, delay_function=DELAY_FUNCTION, prepare_method_name='connect') +RETRY_DECORATOR = retry( + max_retries=MAX_RETRIES, + delay_function=DELAY_FUNCTION, + prepare_method_name="connect", +) + class DbscanServingClient: def __init__(self, host=None, port=None): - if not host: host = get_setting('DBSCANSERVINGSERVICE_SERVICE_HOST') - if not port: port = get_setting('DBSCANSERVINGSERVICE_SERVICE_PORT_GRPC') - self.endpoint = '{:s}:{:s}'.format(str(host), str(port)) - LOGGER.debug('Creating channel to {:s}...'.format(str(self.endpoint))) + if not host: + host = get_setting("DBSCANSERVINGSERVICE_SERVICE_HOST") + if not port: + port = get_setting("DBSCANSERVINGSERVICE_SERVICE_PORT_GRPC") + self.endpoint = "{:s}:{:s}".format(str(host), str(port)) + LOGGER.debug("Creating channel to {:s}...".format(str(self.endpoint))) self.channel = None self.stub = None self.connect() - LOGGER.debug('Channel created') + LOGGER.debug("Channel created") def connect(self): self.channel = grpc.insecure_channel(self.endpoint) self.stub = DetectorStub(self.channel) def close(self): - if(self.channel is not None): self.channel.close() + if self.channel is not None: + self.channel.close() self.channel = None self.stub = None @RETRY_DECORATOR - def Detect(self, request : DetectionRequest) -> DetectionResponse: - LOGGER.debug('Detect request with {} samples and {} features'.format( - request.num_samples, - request.num_features - )) + def Detect(self, request: DetectionRequest) -> DetectionResponse: + LOGGER.debug( + "Detect request with {} samples and {} features".format( + request.num_samples, request.num_features + ) + ) response = self.stub.Detect(request) - LOGGER.debug('Detect result with {} cluster indices'.format(len(response.cluster_indices))) + LOGGER.debug( + "Detect result with {} cluster indices".format( + len(response.cluster_indices) + ) + ) return response diff --git a/src/dbscanserving/client/__init__.py b/src/dbscanserving/client/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/dbscanserving/client/__init__.py +++ b/src/dbscanserving/client/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/dbscanserving/service/DbscanService.py b/src/dbscanserving/service/DbscanService.py index b5d3edbe58db61587e27c871dca796754290c4bb..f91d4f8c5ea3a3858046567f4a781bc2fbd0a1d8 100644 --- a/src/dbscanserving/service/DbscanService.py +++ b/src/dbscanserving/service/DbscanService.py @@ -12,23 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import grpc import logging from concurrent import futures -from grpc_health.v1.health import HealthServicer, OVERALL_HEALTH -from grpc_health.v1.health_pb2 import HealthCheckResponse -from grpc_health.v1.health_pb2_grpc import add_HealthServicer_to_server + +import grpc from common.proto.dbscanserving_pb2_grpc import add_DetectorServicer_to_server +from dbscanserving.Config import GRPC_GRACE_PERIOD, GRPC_MAX_WORKERS, GRPC_SERVICE_PORT from dbscanserving.service.DbscanServiceServicerImpl import DbscanServiceServicerImpl -from dbscanserving.Config import GRPC_SERVICE_PORT, GRPC_MAX_WORKERS, GRPC_GRACE_PERIOD +from grpc_health.v1.health import OVERALL_HEALTH, HealthServicer +from grpc_health.v1.health_pb2 import HealthCheckResponse +from grpc_health.v1.health_pb2_grpc import add_HealthServicer_to_server -BIND_ADDRESS = '0.0.0.0' +BIND_ADDRESS = "0.0.0.0" LOGGER = logging.getLogger(__name__) + class DbscanService: def __init__( - self, address=BIND_ADDRESS, port=GRPC_SERVICE_PORT, max_workers=GRPC_MAX_WORKERS, - grace_period=GRPC_GRACE_PERIOD): + self, + address=BIND_ADDRESS, + port=GRPC_SERVICE_PORT, + max_workers=GRPC_MAX_WORKERS, + grace_period=GRPC_GRACE_PERIOD, + ): self.address = address self.port = port @@ -41,30 +47,41 @@ class DbscanService: self.server = None def start(self): - self.endpoint = '{:s}:{:s}'.format(str(self.address), str(self.port)) - LOGGER.debug('Starting Service (tentative endpoint: {:s}, max_workers: {:s})...'.format( - str(self.endpoint), str(self.max_workers))) + self.endpoint = "{:s}:{:s}".format(str(self.address), str(self.port)) + LOGGER.debug( + "Starting Service (tentative endpoint: {:s}, max_workers: {:s})...".format( + str(self.endpoint), str(self.max_workers) + ) + ) self.pool = futures.ThreadPoolExecutor(max_workers=self.max_workers) - self.server = grpc.server(self.pool) # , interceptors=(tracer_interceptor,)) + self.server = grpc.server(self.pool) # , interceptors=(tracer_interceptor,)) self.dbscan_servicer = DbscanServiceServicerImpl() add_DetectorServicer_to_server(self.dbscan_servicer, self.server) self.health_servicer = HealthServicer( - experimental_non_blocking=True, experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=1)) + experimental_non_blocking=True, + experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=1), + ) add_HealthServicer_to_server(self.health_servicer, self.server) port = self.server.add_insecure_port(self.endpoint) - self.endpoint = '{:s}:{:s}'.format(str(self.address), str(port)) - LOGGER.info('Listening on {:s}...'.format(self.endpoint)) + self.endpoint = "{:s}:{:s}".format(str(self.address), str(port)) + LOGGER.info("Listening on {:s}...".format(self.endpoint)) self.server.start() - self.health_servicer.set(OVERALL_HEALTH, HealthCheckResponse.SERVING) # pylint: disable=maybe-no-member + self.health_servicer.set( + OVERALL_HEALTH, HealthCheckResponse.SERVING + ) # pylint: disable=maybe-no-member - LOGGER.debug('Service started') + LOGGER.debug("Service started") def stop(self): - LOGGER.debug('Stopping service (grace period {:s} seconds)...'.format(str(self.grace_period))) + LOGGER.debug( + "Stopping service (grace period {:s} seconds)...".format( + str(self.grace_period) + ) + ) self.health_servicer.enter_graceful_shutdown() self.server.stop(self.grace_period) - LOGGER.debug('Service stopped') + LOGGER.debug("Service stopped") diff --git a/src/dbscanserving/service/DbscanServiceServicerImpl.py b/src/dbscanserving/service/DbscanServiceServicerImpl.py index 258f3a6194c6924720b1fe19be95871a029b6af7..d9035fc8b5c06279b52beb562a5c053ba0b05106 100644 --- a/src/dbscanserving/service/DbscanServiceServicerImpl.py +++ b/src/dbscanserving/service/DbscanServiceServicerImpl.py @@ -12,34 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os, grpc, logging -from sklearn.cluster import DBSCAN -from common.rpc_method_wrapper.Decorator import create_metrics, safe_and_metered_rpc_method +import logging +import os + +import grpc from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse from common.proto.dbscanserving_pb2_grpc import DetectorServicer +from common.rpc_method_wrapper.Decorator import ( + create_metrics, + safe_and_metered_rpc_method, +) +from sklearn.cluster import DBSCAN LOGGER = logging.getLogger(__name__) -SERVICE_NAME = 'DbscanServing' -METHOD_NAMES = ['Detect'] +SERVICE_NAME = "DbscanServing" +METHOD_NAMES = ["Detect"] METRICS = create_metrics(SERVICE_NAME, METHOD_NAMES) class DbscanServiceServicerImpl(DetectorServicer): - def __init__(self): - LOGGER.debug('Creating Servicer...') - LOGGER.debug('Servicer Created') + LOGGER.debug("Creating Servicer...") + LOGGER.debug("Servicer Created") @safe_and_metered_rpc_method(METRICS, LOGGER) - def Detect(self, request : DetectionRequest, context : grpc.ServicerContext) -> DetectionResponse: + def Detect( + self, request: DetectionRequest, context: grpc.ServicerContext + ) -> DetectionResponse: if request.num_samples != len(request.samples): - context.set_details("The sample dimension declared does not match with the number of samples received.") - LOGGER.debug(f"The sample dimension declared does not match with the number of samples received. Declared: {request.num_samples} - Received: {len(request.samples)}") + context.set_details( + "The sample dimension declared does not match with the number of samples received." + ) + LOGGER.debug( + f"The sample dimension declared does not match with the number of samples received. Declared: {request.num_samples} - Received: {len(request.samples)}" + ) context.set_code(grpc.StatusCode.INVALID_ARGUMENT) return DetectionResponse() # TODO: implement the validation of the features dimension - clusters = DBSCAN(eps=request.eps, min_samples=request.min_samples).fit_predict([[x for x in sample.features] for sample in request.samples]) + clusters = DBSCAN(eps=request.eps, min_samples=request.min_samples).fit_predict( + [[x for x in sample.features] for sample in request.samples] + ) response = DetectionResponse() for cluster in clusters: response.cluster_indices.append(cluster) diff --git a/src/dbscanserving/service/__init__.py b/src/dbscanserving/service/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/dbscanserving/service/__init__.py +++ b/src/dbscanserving/service/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/dbscanserving/service/__main__.py b/src/dbscanserving/service/__main__.py index 10a5dcaee7e68b3c62da4f9adf4493f5331a2a3a..b5ef26922a13b199cf3f5b84ba6bff2a0bd43593 100644 --- a/src/dbscanserving/service/__main__.py +++ b/src/dbscanserving/service/__main__.py @@ -12,35 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging, signal, sys, threading -from prometheus_client import start_http_server +import logging +import signal +import sys +import threading + from common.Settings import get_log_level, get_metrics_port, get_setting -from dbscanserving.Config import ( - GRPC_SERVICE_PORT, GRPC_MAX_WORKERS, GRPC_GRACE_PERIOD) +from dbscanserving.Config import GRPC_GRACE_PERIOD, GRPC_MAX_WORKERS, GRPC_SERVICE_PORT from dbscanserving.service.DbscanService import DbscanService +from prometheus_client import start_http_server terminate = threading.Event() LOGGER = None -def signal_handler(signal, frame): # pylint: disable=redefined-outer-name - LOGGER.warning('Terminate signal received') + +def signal_handler(signal, frame): # pylint: disable=redefined-outer-name + LOGGER.warning("Terminate signal received") terminate.set() + def main(): - global LOGGER # pylint: disable=global-statement + global LOGGER # pylint: disable=global-statement log_level = get_log_level() logging.basicConfig(level=log_level) LOGGER = logging.getLogger(__name__) - service_port = get_setting('DBSCANSERVICE_SERVICE_PORT_GRPC', default=GRPC_SERVICE_PORT) - max_workers = get_setting('MAX_WORKERS', default=GRPC_MAX_WORKERS ) - grace_period = get_setting('GRACE_PERIOD', default=GRPC_GRACE_PERIOD) + service_port = get_setting( + "DBSCANSERVICE_SERVICE_PORT_GRPC", default=GRPC_SERVICE_PORT + ) + max_workers = get_setting("MAX_WORKERS", default=GRPC_MAX_WORKERS) + grace_period = get_setting("GRACE_PERIOD", default=GRPC_GRACE_PERIOD) - signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - LOGGER.info('Starting...') + LOGGER.info("Starting...") # Start metrics server metrics_port = get_metrics_port() @@ -48,17 +55,20 @@ def main(): # Starting CentralizedCybersecurity service grpc_service = DbscanService( - port=service_port, max_workers=max_workers, grace_period=grace_period) + port=service_port, max_workers=max_workers, grace_period=grace_period + ) grpc_service.start() # Wait for Ctrl+C or termination signal - while not terminate.wait(timeout=0.1): pass + while not terminate.wait(timeout=0.1): + pass - LOGGER.info('Terminating...') + LOGGER.info("Terminating...") grpc_service.stop() - LOGGER.info('Bye') + LOGGER.info("Bye") return 0 -if __name__ == '__main__': + +if __name__ == "__main__": sys.exit(main()) diff --git a/src/dbscanserving/tests/__init__.py b/src/dbscanserving/tests/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/dbscanserving/tests/__init__.py +++ b/src/dbscanserving/tests/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/dbscanserving/tests/test_unitary.py b/src/dbscanserving/tests/test_unitary.py index 4156734e6e5f80dce40f59e88da2a4494bcb3f0f..82db7882c2d1fa37277a0bd374f2cc71ac745a36 100644 --- a/src/dbscanserving/tests/test_unitary.py +++ b/src/dbscanserving/tests/test_unitary.py @@ -12,32 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random, logging, pytest, numpy -from dbscanserving.Config import GRPC_SERVICE_PORT, GRPC_MAX_WORKERS, GRPC_GRACE_PERIOD +import logging +import random + +import numpy +import pytest +from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample from dbscanserving.client.DbscanServingClient import DbscanServingClient +from dbscanserving.Config import GRPC_GRACE_PERIOD, GRPC_MAX_WORKERS, GRPC_SERVICE_PORT from dbscanserving.service.DbscanService import DbscanService -from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample -port = 10000 + GRPC_SERVICE_PORT # avoid privileged ports +port = 10000 + GRPC_SERVICE_PORT # avoid privileged ports LOGGER = logging.getLogger(__name__) LOGGER.setLevel(logging.DEBUG) -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def dbscanserving_service(): _service = DbscanService( - port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD) + port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD + ) _service.start() yield _service _service.stop() -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def dbscanserving_client(): - _client = DbscanServingClient(address='127.0.0.1', port=port) + _client = DbscanServingClient(address="127.0.0.1", port=port) yield _client _client.close() -def test_detection_correct(dbscanserving_service, dbscanserving_client: DbscanServingClient): + +def test_detection_correct( + dbscanserving_service, dbscanserving_client: DbscanServingClient +): request: DetectionRequest = DetectionRequest() request.num_samples = 310 @@ -48,25 +58,28 @@ def test_detection_correct(dbscanserving_service, dbscanserving_client: DbscanSe for _ in range(200): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(0., 10.)) + grpc_sample.features.append(random.uniform(0.0, 10.0)) request.samples.append(grpc_sample) - + for _ in range(100): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(50., 60.)) + grpc_sample.features.append(random.uniform(50.0, 60.0)) request.samples.append(grpc_sample) - + for _ in range(10): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(5000., 6000.)) + grpc_sample.features.append(random.uniform(5000.0, 6000.0)) request.samples.append(grpc_sample) response: DetectionResponse = dbscanserving_client.Detect(request) assert len(response.cluster_indices) == 310 -def test_detection_incorrect(dbscanserving_service, dbscanserving_client: DbscanServingClient): + +def test_detection_incorrect( + dbscanserving_service, dbscanserving_client: DbscanServingClient +): request: DetectionRequest = DetectionRequest() request.num_samples = 210 @@ -77,25 +90,28 @@ def test_detection_incorrect(dbscanserving_service, dbscanserving_client: Dbscan for _ in range(200): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(0., 10.)) + grpc_sample.features.append(random.uniform(0.0, 10.0)) request.samples.append(grpc_sample) - + for _ in range(100): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(50., 60.)) + grpc_sample.features.append(random.uniform(50.0, 60.0)) request.samples.append(grpc_sample) - + for _ in range(10): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(5000., 6000.)) + grpc_sample.features.append(random.uniform(5000.0, 6000.0)) request.samples.append(grpc_sample) with pytest.raises(Exception): response: DetectionResponse = dbscanserving_client.Detect(request) -def test_detection_clusters(dbscanserving_service, dbscanserving_client: DbscanServingClient): + +def test_detection_clusters( + dbscanserving_service, dbscanserving_client: DbscanServingClient +): request: DetectionRequest = DetectionRequest() request.num_samples = 310 @@ -106,19 +122,19 @@ def test_detection_clusters(dbscanserving_service, dbscanserving_client: DbscanS for _ in range(200): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(0., 10.)) + grpc_sample.features.append(random.uniform(0.0, 10.0)) request.samples.append(grpc_sample) - + for _ in range(100): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(50., 60.)) + grpc_sample.features.append(random.uniform(50.0, 60.0)) request.samples.append(grpc_sample) - + for _ in range(10): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(5000., 6000.)) + grpc_sample.features.append(random.uniform(5000.0, 6000.0)) request.samples.append(grpc_sample) response: DetectionResponse = dbscanserving_client.Detect(request) diff --git a/src/opticalattackdetector/Dockerfile b/src/opticalattackdetector/Dockerfile index 0b5adc3c5e322c146b510d495852a972594f215a..ca745d85ad534fc816b20e2eb818301a1bde4d9f 100644 --- a/src/opticalattackdetector/Dockerfile +++ b/src/opticalattackdetector/Dockerfile @@ -77,8 +77,8 @@ RUN python3 -m pip install -r opticalattackdetector/requirements.txt # Add files into working directory COPY --chown=opticalattackdetector:opticalattackdetector src/context/. context -COPY --chown=opticalattackdetector:opticalattackdetector src/monitoring/. monitoring COPY --chown=opticalattackdetector:opticalattackdetector src/service/. service +COPY --chown=opticalattackdetector:opticalattackdetector src/monitoring/. monitoring COPY --chown=opticalattackdetector:opticalattackdetector src/dbscanserving/. dbscanserving COPY --chown=opticalattackdetector:opticalattackdetector src/opticalattackmitigator/. opticalattackmitigator COPY --chown=opticalattackdetector:opticalattackdetector src/opticalattackdetector/. opticalattackdetector diff --git a/src/opticalattackdetector/__init__.py b/src/opticalattackdetector/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/opticalattackdetector/__init__.py +++ b/src/opticalattackdetector/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/opticalattackdetector/client/OpticalAttackDetectorClient.py b/src/opticalattackdetector/client/OpticalAttackDetectorClient.py index 846bcce6e5a77d14ca1735a45b456e8c3865bc5b..2571efeaacf9147196b0409b144992ee217d9629 100644 --- a/src/opticalattackdetector/client/OpticalAttackDetectorClient.py +++ b/src/opticalattackdetector/client/OpticalAttackDetectorClient.py @@ -12,59 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. -import grpc, logging -from common.tools.client.RetryDecorator import retry, delay_exponential +import logging + +import grpc from common.proto.context_pb2 import Empty, Service from common.proto.monitoring_pb2 import KpiList -from common.proto.optical_attack_detector_pb2_grpc import OpticalAttackDetectorServiceStub +from common.proto.optical_attack_detector_pb2_grpc import ( + OpticalAttackDetectorServiceStub, +) +from common.tools.client.RetryDecorator import delay_exponential, retry +from common.tools.grpc.Tools import grpc_message_to_json LOGGER = logging.getLogger(__name__) MAX_RETRIES = 15 DELAY_FUNCTION = delay_exponential(initial=0.01, increment=2.0, maximum=5.0) -RETRY_DECORATOR = retry(max_retries=MAX_RETRIES, delay_function=DELAY_FUNCTION, prepare_method_name='connect') +RETRY_DECORATOR = retry( + max_retries=MAX_RETRIES, + delay_function=DELAY_FUNCTION, + prepare_method_name="connect", +) + class OpticalAttackDetectorClient: def __init__(self, address, port): - self.endpoint = '{:s}:{:s}'.format(str(address), str(port)) - LOGGER.debug('Creating channel to {:s}...'.format(str(self.endpoint))) + self.endpoint = "{:s}:{:s}".format(str(address), str(port)) + LOGGER.debug("Creating channel to {:s}...".format(str(self.endpoint))) self.channel = None self.stub = None self.connect() - LOGGER.debug('Channel created') + LOGGER.debug("Channel created") def connect(self): self.channel = grpc.insecure_channel(self.endpoint) self.stub = OpticalAttackDetectorServiceStub(self.channel) def close(self): - if(self.channel is not None): self.channel.close() + if self.channel is not None: + self.channel.close() self.channel = None self.stub = None @RETRY_DECORATOR - def NotifyServiceUpdate(self, request : Service) -> Empty: - LOGGER.debug('NotifyServiceUpdate request: {:s}'.format(str(request))) - response = self.stub.NotifyServiceUpdate(request) - LOGGER.debug('NotifyServiceUpdate result: {:s}'.format(str(response))) - return response - - @RETRY_DECORATOR - def DetectAttack(self, request : Empty) -> Empty: - LOGGER.debug('DetectAttack request: {:s}'.format(str(request))) + def DetectAttack(self, request: Empty) -> Empty: + LOGGER.debug( + "DetectAttack request: {:s}".format(str(grpc_message_to_json(request))) + ) response = self.stub.DetectAttack(request) - LOGGER.debug('DetectAttack result: {:s}'.format(str(response))) - return response - - @RETRY_DECORATOR - def ReportSummarizedKpi(self, request : KpiList) -> Empty: - LOGGER.debug('ReportSummarizedKpi request: {:s}'.format(str(request))) - response = self.stub.ReportSummarizedKpi(request) - LOGGER.debug('ReportSummarizedKpi result: {:s}'.format(str(response))) - return response - - @RETRY_DECORATOR - def ReportKpi(self, request : KpiList) -> Empty: - LOGGER.debug('ReportKpi request: {:s}'.format(str(request))) - response = self.stub.ReportKpi(request) - LOGGER.debug('ReportKpi result: {:s}'.format(str(response))) + LOGGER.debug( + "DetectAttack result: {:s}".format(str(grpc_message_to_json(response))) + ) return response diff --git a/src/opticalattackdetector/client/__init__.py b/src/opticalattackdetector/client/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/opticalattackdetector/client/__init__.py +++ b/src/opticalattackdetector/client/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/opticalattackdetector/service/OpticalAttackDetectorService.py b/src/opticalattackdetector/service/OpticalAttackDetectorService.py index 9bc1cc29758037b04ec2580fe94c0ec0c7e56f32..17fdaa1bd4c3cf14cfc1af6127cdd9c10f45c1fa 100644 --- a/src/opticalattackdetector/service/OpticalAttackDetectorService.py +++ b/src/opticalattackdetector/service/OpticalAttackDetectorService.py @@ -12,27 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import grpc import logging from concurrent import futures -from grpc_health.v1.health import HealthServicer, OVERALL_HEALTH +import grpc +from common.Constants import ( + DEFAULT_GRPC_BIND_ADDRESS, + DEFAULT_GRPC_GRACE_PERIOD, + DEFAULT_GRPC_MAX_WORKERS, +) +from common.proto.optical_attack_detector_pb2_grpc import ( + add_OpticalAttackDetectorServiceServicer_to_server, +) +from grpc_health.v1.health import OVERALL_HEALTH, HealthServicer from grpc_health.v1.health_pb2 import HealthCheckResponse from grpc_health.v1.health_pb2_grpc import add_HealthServicer_to_server - -from common.Constants import DEFAULT_GRPC_BIND_ADDRESS, DEFAULT_GRPC_MAX_WORKERS, DEFAULT_GRPC_GRACE_PERIOD -from common.proto.optical_attack_detector_pb2_grpc import ( - add_OpticalAttackDetectorServiceServicer_to_server) -from opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl import ( - OpticalAttackDetectorServiceServicerImpl) from opticalattackdetector.Config import GRPC_SERVICE_PORT +from opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl import ( + OpticalAttackDetectorServiceServicerImpl, +) LOGGER = logging.getLogger(__name__) + class OpticalAttackDetectorService: def __init__( - self, address=DEFAULT_GRPC_BIND_ADDRESS, port=GRPC_SERVICE_PORT, max_workers=DEFAULT_GRPC_MAX_WORKERS, - grace_period=DEFAULT_GRPC_GRACE_PERIOD): + self, + address=DEFAULT_GRPC_BIND_ADDRESS, + port=GRPC_SERVICE_PORT, + max_workers=DEFAULT_GRPC_MAX_WORKERS, + grace_period=DEFAULT_GRPC_GRACE_PERIOD, + ): self.address = address self.port = port @@ -45,30 +55,43 @@ class OpticalAttackDetectorService: self.server = None def start(self): - self.endpoint = '{:s}:{:s}'.format(str(self.address), str(self.port)) - LOGGER.debug('Starting Service (tentative endpoint: {:s}, max_workers: {:s})...'.format( - str(self.endpoint), str(self.max_workers))) + self.endpoint = "{:s}:{:s}".format(str(self.address), str(self.port)) + LOGGER.debug( + "Starting Service (tentative endpoint: {:s}, max_workers: {:s})...".format( + str(self.endpoint), str(self.max_workers) + ) + ) self.pool = futures.ThreadPoolExecutor(max_workers=self.max_workers) - self.server = grpc.server(self.pool) # , interceptors=(tracer_interceptor,)) + self.server = grpc.server(self.pool) # , interceptors=(tracer_interceptor,)) self.attack_detector_servicer = OpticalAttackDetectorServiceServicerImpl() - add_OpticalAttackDetectorServiceServicer_to_server(self.attack_detector_servicer, self.server) + add_OpticalAttackDetectorServiceServicer_to_server( + self.attack_detector_servicer, self.server + ) self.health_servicer = HealthServicer( - experimental_non_blocking=True, experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=1)) + experimental_non_blocking=True, + experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=1), + ) add_HealthServicer_to_server(self.health_servicer, self.server) port = self.server.add_insecure_port(self.endpoint) - self.endpoint = '{:s}:{:s}'.format(str(self.address), str(port)) - LOGGER.info('Listening on {:s}...'.format(self.endpoint)) + self.endpoint = "{:s}:{:s}".format(str(self.address), str(port)) + LOGGER.info("Listening on {:s}...".format(self.endpoint)) self.server.start() - self.health_servicer.set(OVERALL_HEALTH, HealthCheckResponse.SERVING) # pylint: disable=maybe-no-member + self.health_servicer.set( + OVERALL_HEALTH, HealthCheckResponse.SERVING + ) # pylint: disable=maybe-no-member - LOGGER.debug('Service started') + LOGGER.debug("Service started") def stop(self): - LOGGER.debug('Stopping service (grace period {:s} seconds)...'.format(str(self.grace_period))) + LOGGER.debug( + "Stopping service (grace period {:s} seconds)...".format( + str(self.grace_period) + ) + ) self.health_servicer.enter_graceful_shutdown() self.server.stop(self.grace_period) - LOGGER.debug('Service stopped') + LOGGER.debug("Service stopped") diff --git a/src/opticalattackdetector/service/OpticalAttackDetectorServiceServicerImpl.py b/src/opticalattackdetector/service/OpticalAttackDetectorServiceServicerImpl.py index d447859818a4f2d586b101bd4f39bd3965058a30..fd2ac85c38a9c9988f50f4a64579d9f83e85aabc 100644 --- a/src/opticalattackdetector/service/OpticalAttackDetectorServiceServicerImpl.py +++ b/src/opticalattackdetector/service/OpticalAttackDetectorServiceServicerImpl.py @@ -12,29 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os, grpc, logging, random +import logging +import random + +import grpc +from common.proto.context_pb2 import Empty, ServiceId +from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample +from common.proto.monitoring_pb2 import Kpi +from common.proto.optical_attack_detector_pb2_grpc import ( + OpticalAttackDetectorServiceServicer, +) +from common.proto.optical_attack_mitigator_pb2 import AttackDescription, AttackResponse +from common.rpc_method_wrapper.Decorator import ( + create_metrics, + safe_and_metered_rpc_method, +) from common.tools.timestamp.Converters import timestamp_utcnow_to_float -from common.rpc_method_wrapper.Decorator import create_metrics, safe_and_metered_rpc_method from context.client.ContextClient import ContextClient -from monitoring.client.MonitoringClient import MonitoringClient -from service.client.ServiceClient import ServiceClient -from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample from dbscanserving.client.DbscanServingClient import DbscanServingClient -from opticalattackmitigator.client.OpticalAttackMitigatorClient import OpticalAttackMitigatorClient -from common.proto.optical_attack_mitigator_pb2 import AttackDescription, AttackResponse -from common.proto.context_pb2 import (Empty, - Context, ContextId, ContextIdList, ContextList, - Service, ServiceId, ServiceIdList, ServiceList +from monitoring.client.MonitoringClient import MonitoringClient +from opticalattackmitigator.client.OpticalAttackMitigatorClient import ( + OpticalAttackMitigatorClient, ) -from common.proto.monitoring_pb2 import Kpi -from common.proto.optical_attack_detector_pb2_grpc import ( - OpticalAttackDetectorServiceServicer) - +from service.client.ServiceClient import ServiceClient LOGGER = logging.getLogger(__name__) -SERVICE_NAME = 'OpticalAttackDetector' -METHOD_NAMES = ['DetectAttack'] +SERVICE_NAME = "OpticalAttackDetector" +METHOD_NAMES = ["DetectAttack"] METRICS = create_metrics(SERVICE_NAME, METHOD_NAMES) context_client: ContextClient = ContextClient() @@ -45,17 +50,19 @@ attack_mitigator_client: OpticalAttackMitigatorClient = OpticalAttackMitigatorCl class OpticalAttackDetectorServiceServicerImpl(OpticalAttackDetectorServiceServicer): - def __init__(self): - LOGGER.debug('Creating Servicer...') - LOGGER.debug('Servicer Created') + LOGGER.debug("Creating Servicer...") + LOGGER.debug("Servicer Created") @safe_and_metered_rpc_method(METRICS, LOGGER) - def DetectAttack(self, service_id : ServiceId, context : grpc.ServicerContext) -> Empty: - LOGGER.debug('Received request for {}/{}...'.format( - service_id.context_id.context_uuid.uuid, - service_id.service_uuid.uuid - )) + def DetectAttack( + self, service_id: ServiceId, context: grpc.ServicerContext + ) -> Empty: + LOGGER.debug( + "Received request for {}/{}...".format( + service_id.context_id.context_uuid.uuid, service_id.service_uuid.uuid + ) + ) # run attack detection for every service request: DetectionRequest = DetectionRequest() request.num_samples = 310 @@ -65,25 +72,26 @@ class OpticalAttackDetectorServiceServicerImpl(OpticalAttackDetectorServiceServi for _ in range(200): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(0., 10.)) + grpc_sample.features.append(random.uniform(0.0, 10.0)) request.samples.append(grpc_sample) for _ in range(100): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(50., 60.)) + grpc_sample.features.append(random.uniform(50.0, 60.0)) request.samples.append(grpc_sample) for _ in range(10): grpc_sample = Sample() for __ in range(100): - grpc_sample.features.append(random.uniform(5000., 6000.)) + grpc_sample.features.append(random.uniform(5000.0, 6000.0)) request.samples.append(grpc_sample) response: DetectionResponse = dbscanserving_client.Detect(request) # including KPI + # TODO: set kpi_id and kpi_value according to the service kpi = Kpi() - kpi.kpi_id.kpi_id.uuid = "1" + kpi.kpi_id.kpi_id.uuid = random.choice(["1", "2"]) kpi.timestamp.timestamp = timestamp_utcnow_to_float() - kpi.kpi_value.int32Val = response.cluster_indices[-1] + kpi.kpi_value.int32Val = random.choice([-1, 0, 1]) # response.cluster_indices[-1] monitoring_client.IncludeKpi(kpi) if -1 in response.cluster_indices: # attack detected diff --git a/src/opticalattackdetector/service/__init__.py b/src/opticalattackdetector/service/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/opticalattackdetector/service/__init__.py +++ b/src/opticalattackdetector/service/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/opticalattackdetector/service/__main__.py b/src/opticalattackdetector/service/__main__.py index dac77750ca9087ce29467042f1c7dd08c1713313..cf481769486cbf4bad3712b7f2c7bf9218719480 100644 --- a/src/opticalattackdetector/service/__main__.py +++ b/src/opticalattackdetector/service/__main__.py @@ -12,28 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging, signal, sys, time, threading, random +import logging +import signal +import sys +import threading from multiprocessing import Manager, Process -from prometheus_client import start_http_server -import asyncio - -from common.Constants import DEFAULT_GRPC_MAX_WORKERS, DEFAULT_GRPC_GRACE_PERIOD -from common.Settings import get_log_level, get_metrics_port, get_setting -from common.tools.timestamp.Converters import timestamp_utcnow_to_float -from opticalattackdetector.Config import ( - GRPC_SERVICE_PORT, MONITORING_INTERVAL) -from common.proto.context_pb2 import (Empty, - Context, ContextId, ContextIdList, ContextList, - Service, ServiceId, ServiceIdList, ServiceList, Timestamp + +from common.Constants import ( + DEFAULT_GRPC_GRACE_PERIOD, + DEFAULT_GRPC_MAX_WORKERS, + ServiceNameEnum, +) +from common.Settings import ( + ENVVAR_SUFIX_SERVICE_HOST, + ENVVAR_SUFIX_SERVICE_PORT_GRPC, + get_env_var_name, + get_log_level, + get_metrics_port, + get_setting, + wait_for_environment_variables, +) +from opticalattackdetector.client.OpticalAttackDetectorClient import ( + OpticalAttackDetectorClient, ) -from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample -from common.proto.attack_mitigator_pb2 import AttackDescription, AttackResponse -from dbscanserving.client.DbscanServingClient import DbscanServingClient -from opticalattackmitigator.client.OpticalAttackMitigatorClient import OpticalAttackMitigatorClient -from opticalattackdetector.service.OpticalAttackDetectorService import OpticalAttackDetectorService -from opticalattackdetector.client.OpticalAttackDetectorClient import OpticalAttackDetectorClient -from monitoring.client.MonitoringClient import MonitoringClient -from common.proto.monitoring_pb2 import Kpi +from opticalattackdetector.Config import GRPC_SERVICE_PORT +from opticalattackdetector.service.OpticalAttackDetectorService import ( + OpticalAttackDetectorService, +) +from prometheus_client import start_http_server terminate = threading.Event() LOGGER = None @@ -41,26 +47,48 @@ LOGGER = None client: OpticalAttackDetectorClient = None -def signal_handler(signal, frame): # pylint: disable=redefined-outer-name - LOGGER.warning('Terminate signal received') +def signal_handler(signal, frame): # pylint: disable=redefined-outer-name + LOGGER.warning("Terminate signal received") terminate.set() def main(): - global LOGGER # pylint: disable=global-statement + global LOGGER # pylint: disable=global-statement log_level = get_log_level() logging.basicConfig(level=log_level) LOGGER = logging.getLogger(__name__) - service_port = get_setting('OPTICALATTACKDETECTORSERVICE_SERVICE_PORT_GRPC', default=GRPC_SERVICE_PORT) - max_workers = get_setting('MAX_WORKERS', default=DEFAULT_GRPC_MAX_WORKERS ) - grace_period = get_setting('GRACE_PERIOD', default=DEFAULT_GRPC_GRACE_PERIOD) - - signal.signal(signal.SIGINT, signal_handler) + wait_for_environment_variables( + [ + get_env_var_name(ServiceNameEnum.DBSCANSERVING, ENVVAR_SUFIX_SERVICE_HOST), + get_env_var_name( + ServiceNameEnum.DBSCANSERVING, ENVVAR_SUFIX_SERVICE_PORT_GRPC + ), + ] + ) + + wait_for_environment_variables( + [ + get_env_var_name( + ServiceNameEnum.OPTICALATTACKMITIGATOR, ENVVAR_SUFIX_SERVICE_HOST + ), + get_env_var_name( + ServiceNameEnum.OPTICALATTACKMITIGATOR, ENVVAR_SUFIX_SERVICE_PORT_GRPC + ), + ] + ) + + service_port = get_setting( + "OPTICALATTACKDETECTORSERVICE_SERVICE_PORT_GRPC", default=GRPC_SERVICE_PORT + ) + max_workers = get_setting("MAX_WORKERS", default=DEFAULT_GRPC_MAX_WORKERS) + grace_period = get_setting("GRACE_PERIOD", default=DEFAULT_GRPC_GRACE_PERIOD) + + signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - LOGGER.info('Starting...') + LOGGER.info("Starting...") # Start metrics server metrics_port = get_metrics_port() @@ -68,18 +96,21 @@ def main(): # Starting CentralizedCybersecurity service grpc_service = OpticalAttackDetectorService( - port=service_port, max_workers=max_workers, grace_period=grace_period) + port=service_port, max_workers=max_workers, grace_period=grace_period + ) grpc_service.start() # Wait for Ctrl+C or termination signal - while not terminate.wait(timeout=0.1): pass + while not terminate.wait(timeout=0.1): + pass - LOGGER.info('Terminating...') + LOGGER.info("Terminating...") grpc_service.stop() # p.kill() - LOGGER.info('Bye') + LOGGER.info("Bye") return 0 -if __name__ == '__main__': + +if __name__ == "__main__": sys.exit(main()) diff --git a/src/opticalattackdetector/tests/__init__.py b/src/opticalattackdetector/tests/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/opticalattackdetector/tests/__init__.py +++ b/src/opticalattackdetector/tests/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/opticalattackdetector/tests/example_objects.py b/src/opticalattackdetector/tests/example_objects.py index 642cba009e6dba5d7609c3941b0b43f3d20e27d4..21809839d3741bde1ae43841b78f04da984b3f9e 100644 --- a/src/opticalattackdetector/tests/example_objects.py +++ b/src/opticalattackdetector/tests/example_objects.py @@ -13,203 +13,239 @@ # limitations under the License. from copy import deepcopy + from common.Constants import DEFAULT_CONTEXT_UUID, DEFAULT_TOPOLOGY_UUID from common.proto.context_pb2 import ( - ConfigActionEnum, DeviceDriverEnum, DeviceOperationalStatusEnum, ServiceStatusEnum, ServiceTypeEnum) + ConfigActionEnum, + DeviceDriverEnum, + DeviceOperationalStatusEnum, + ServiceStatusEnum, + ServiceTypeEnum, +) # Some example objects to be used by the tests # Helper methods def config_rule(action, resource_value): - return {'action': action, 'resource_value': resource_value} + return {"action": action, "resource_value": resource_value} + def endpoint_id(topology_id, device_id, endpoint_uuid): - return {'topology_id': deepcopy(topology_id), 'device_id': deepcopy(device_id), - 'endpoint_uuid': {'uuid': endpoint_uuid}} + return { + "topology_id": deepcopy(topology_id), + "device_id": deepcopy(device_id), + "endpoint_uuid": {"uuid": endpoint_uuid}, + } + def endpoint(topology_id, device_id, endpoint_uuid, endpoint_type): - return {'endpoint_id': endpoint_id(topology_id, device_id, endpoint_uuid), 'endpoint_type': endpoint_type} + return { + "endpoint_id": endpoint_id(topology_id, device_id, endpoint_uuid), + "endpoint_type": endpoint_type, + } + ## use "deepcopy" to prevent propagating forced changes during tests -CONTEXT_ID = {'context_uuid': {'uuid': DEFAULT_CONTEXT_UUID}} +CONTEXT_ID = {"context_uuid": {"uuid": DEFAULT_CONTEXT_UUID}} CONTEXT = { - 'context_id': deepcopy(CONTEXT_ID), - 'topology_ids': [], - 'service_ids': [], + "context_id": deepcopy(CONTEXT_ID), + "topology_ids": [], + "service_ids": [], } -CONTEXT_ID_2 = {'context_uuid': {'uuid': 'test'}} +CONTEXT_ID_2 = {"context_uuid": {"uuid": "test"}} CONTEXT_2 = { - 'context_id': deepcopy(CONTEXT_ID_2), - 'topology_ids': [], - 'service_ids': [], + "context_id": deepcopy(CONTEXT_ID_2), + "topology_ids": [], + "service_ids": [], } TOPOLOGY_ID = { - 'context_id': deepcopy(CONTEXT_ID), - 'topology_uuid': {'uuid': DEFAULT_TOPOLOGY_UUID}, + "context_id": deepcopy(CONTEXT_ID), + "topology_uuid": {"uuid": DEFAULT_TOPOLOGY_UUID}, } TOPOLOGY = { - 'topology_id': deepcopy(TOPOLOGY_ID), - 'device_ids': [], - 'link_ids': [], + "topology_id": deepcopy(TOPOLOGY_ID), + "device_ids": [], + "link_ids": [], } -DEVICE1_UUID = 'DEV1' -DEVICE1_ID = {'device_uuid': {'uuid': DEVICE1_UUID}} +DEVICE1_UUID = "DEV1" +DEVICE1_ID = {"device_uuid": {"uuid": DEVICE1_UUID}} DEVICE1 = { - 'device_id': deepcopy(DEVICE1_ID), - 'device_type': 'packet-router', - 'device_config': {'config_rules': [ - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'value1'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'value2'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'value3'), - ]}, - 'device_operational_status': DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_ENABLED, - 'device_drivers': [DeviceDriverEnum.DEVICEDRIVER_OPENCONFIG, DeviceDriverEnum.DEVICEDRIVER_P4], - 'device_endpoints': [ - endpoint(TOPOLOGY_ID, DEVICE1_ID, 'EP2', 'port-packet-100G'), - endpoint(TOPOLOGY_ID, DEVICE1_ID, 'EP3', 'port-packet-100G'), - endpoint(TOPOLOGY_ID, DEVICE1_ID, 'EP100', 'port-packet-10G'), + "device_id": deepcopy(DEVICE1_ID), + "device_type": "packet-router", + "device_config": { + "config_rules": [ + config_rule(ConfigActionEnum.CONFIGACTION_SET, "value1"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "value2"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "value3"), + ] + }, + "device_operational_status": DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_ENABLED, + "device_drivers": [ + DeviceDriverEnum.DEVICEDRIVER_OPENCONFIG, + DeviceDriverEnum.DEVICEDRIVER_P4, + ], + "device_endpoints": [ + endpoint(TOPOLOGY_ID, DEVICE1_ID, "EP2", "port-packet-100G"), + endpoint(TOPOLOGY_ID, DEVICE1_ID, "EP3", "port-packet-100G"), + endpoint(TOPOLOGY_ID, DEVICE1_ID, "EP100", "port-packet-10G"), ], } -DEVICE2_UUID = 'DEV2' -DEVICE2_ID = {'device_uuid': {'uuid': DEVICE2_UUID}} +DEVICE2_UUID = "DEV2" +DEVICE2_ID = {"device_uuid": {"uuid": DEVICE2_UUID}} DEVICE2 = { - 'device_id': deepcopy(DEVICE2_ID), - 'device_type': 'packet-router', - 'device_config': {'config_rules': [ - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'dev/rsrc1/value', 'value4'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'dev/rsrc2/value', 'value5'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'dev/rsrc3/value', 'value6'), - ]}, - 'device_operational_status': DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_ENABLED, - 'device_drivers': [DeviceDriverEnum.DEVICEDRIVER_OPENCONFIG, DeviceDriverEnum.DEVICEDRIVER_P4], - 'device_endpoints': [ - endpoint(TOPOLOGY_ID, DEVICE2_ID, 'EP1', 'port-packet-100G'), - endpoint(TOPOLOGY_ID, DEVICE2_ID, 'EP3', 'port-packet-100G'), - endpoint(TOPOLOGY_ID, DEVICE2_ID, 'EP100', 'port-packet-10G'), + "device_id": deepcopy(DEVICE2_ID), + "device_type": "packet-router", + "device_config": { + "config_rules": [ + config_rule(ConfigActionEnum.CONFIGACTION_SET, "dev/rsrc1/value", "value4"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "dev/rsrc2/value", "value5"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "dev/rsrc3/value", "value6"), + ] + }, + "device_operational_status": DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_ENABLED, + "device_drivers": [ + DeviceDriverEnum.DEVICEDRIVER_OPENCONFIG, + DeviceDriverEnum.DEVICEDRIVER_P4, + ], + "device_endpoints": [ + endpoint(TOPOLOGY_ID, DEVICE2_ID, "EP1", "port-packet-100G"), + endpoint(TOPOLOGY_ID, DEVICE2_ID, "EP3", "port-packet-100G"), + endpoint(TOPOLOGY_ID, DEVICE2_ID, "EP100", "port-packet-10G"), ], } -DEVICE3_UUID = 'DEV3' -DEVICE3_ID = {'device_uuid': {'uuid': DEVICE3_UUID}} +DEVICE3_UUID = "DEV3" +DEVICE3_ID = {"device_uuid": {"uuid": DEVICE3_UUID}} DEVICE3 = { - 'device_id': deepcopy(DEVICE3_ID), - 'device_type': 'packet-router', - 'device_config': {'config_rules': [ - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'dev/rsrc1/value', 'value4'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'dev/rsrc2/value', 'value5'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'dev/rsrc3/value', 'value6'), - ]}, - 'device_operational_status': DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_ENABLED, - 'device_drivers': [DeviceDriverEnum.DEVICEDRIVER_OPENCONFIG, DeviceDriverEnum.DEVICEDRIVER_P4], - 'device_endpoints': [ - endpoint(TOPOLOGY_ID, DEVICE3_ID, 'EP1', 'port-packet-100G'), - endpoint(TOPOLOGY_ID, DEVICE3_ID, 'EP2', 'port-packet-100G'), - endpoint(TOPOLOGY_ID, DEVICE3_ID, 'EP100', 'port-packet-10G'), + "device_id": deepcopy(DEVICE3_ID), + "device_type": "packet-router", + "device_config": { + "config_rules": [ + config_rule(ConfigActionEnum.CONFIGACTION_SET, "dev/rsrc1/value", "value4"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "dev/rsrc2/value", "value5"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "dev/rsrc3/value", "value6"), + ] + }, + "device_operational_status": DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_ENABLED, + "device_drivers": [ + DeviceDriverEnum.DEVICEDRIVER_OPENCONFIG, + DeviceDriverEnum.DEVICEDRIVER_P4, + ], + "device_endpoints": [ + endpoint(TOPOLOGY_ID, DEVICE3_ID, "EP1", "port-packet-100G"), + endpoint(TOPOLOGY_ID, DEVICE3_ID, "EP2", "port-packet-100G"), + endpoint(TOPOLOGY_ID, DEVICE3_ID, "EP100", "port-packet-10G"), ], } -LINK_DEV1_DEV2_UUID = 'DEV1/EP2 ==> DEV2/EP1' -LINK_DEV1_DEV2_ID = {'link_uuid': {'uuid': LINK_DEV1_DEV2_UUID}} +LINK_DEV1_DEV2_UUID = "DEV1/EP2 ==> DEV2/EP1" +LINK_DEV1_DEV2_ID = {"link_uuid": {"uuid": LINK_DEV1_DEV2_UUID}} LINK_DEV1_DEV2 = { - 'link_id': deepcopy(LINK_DEV1_DEV2_ID), - 'link_endpoint_ids' : [ - endpoint_id(TOPOLOGY_ID, DEVICE1_ID, 'EP2'), - endpoint_id(TOPOLOGY_ID, DEVICE2_ID, 'EP1'), - ] + "link_id": deepcopy(LINK_DEV1_DEV2_ID), + "link_endpoint_ids": [ + endpoint_id(TOPOLOGY_ID, DEVICE1_ID, "EP2"), + endpoint_id(TOPOLOGY_ID, DEVICE2_ID, "EP1"), + ], } -LINK_DEV2_DEV3_UUID = 'DEV2/EP3 ==> DEV3/EP2' -LINK_DEV2_DEV3_ID = {'link_uuid': {'uuid': LINK_DEV2_DEV3_UUID}} +LINK_DEV2_DEV3_UUID = "DEV2/EP3 ==> DEV3/EP2" +LINK_DEV2_DEV3_ID = {"link_uuid": {"uuid": LINK_DEV2_DEV3_UUID}} LINK_DEV2_DEV3 = { - 'link_id': deepcopy(LINK_DEV2_DEV3_ID), - 'link_endpoint_ids' : [ - endpoint_id(TOPOLOGY_ID, DEVICE2_ID, 'EP3'), - endpoint_id(TOPOLOGY_ID, DEVICE3_ID, 'EP2'), - ] + "link_id": deepcopy(LINK_DEV2_DEV3_ID), + "link_endpoint_ids": [ + endpoint_id(TOPOLOGY_ID, DEVICE2_ID, "EP3"), + endpoint_id(TOPOLOGY_ID, DEVICE3_ID, "EP2"), + ], } -LINK_DEV1_DEV3_UUID = 'DEV1/EP3 ==> DEV3/EP1' -LINK_DEV1_DEV3_ID = {'link_uuid': {'uuid': LINK_DEV1_DEV3_UUID}} +LINK_DEV1_DEV3_UUID = "DEV1/EP3 ==> DEV3/EP1" +LINK_DEV1_DEV3_ID = {"link_uuid": {"uuid": LINK_DEV1_DEV3_UUID}} LINK_DEV1_DEV3 = { - 'link_id': deepcopy(LINK_DEV1_DEV3_ID), - 'link_endpoint_ids' : [ - endpoint_id(TOPOLOGY_ID, DEVICE1_ID, 'EP3'), - endpoint_id(TOPOLOGY_ID, DEVICE3_ID, 'EP1'), - ] + "link_id": deepcopy(LINK_DEV1_DEV3_ID), + "link_endpoint_ids": [ + endpoint_id(TOPOLOGY_ID, DEVICE1_ID, "EP3"), + endpoint_id(TOPOLOGY_ID, DEVICE3_ID, "EP1"), + ], } -SERVICE_DEV1_DEV2_UUID = 'SVC:DEV1/EP100-DEV2/EP100' +SERVICE_DEV1_DEV2_UUID = "SVC:DEV1/EP100-DEV2/EP100" SERVICE_DEV1_DEV2_ID = { - 'context_id': deepcopy(CONTEXT_ID), - 'service_uuid': {'uuid': SERVICE_DEV1_DEV2_UUID}, + "context_id": deepcopy(CONTEXT_ID), + "service_uuid": {"uuid": SERVICE_DEV1_DEV2_UUID}, } SERVICE_DEV1_DEV2 = { - 'service_id': deepcopy(SERVICE_DEV1_DEV2_ID), - 'service_type': ServiceTypeEnum.SERVICETYPE_L3NM, - 'service_endpoint_ids' : [ - endpoint_id(TOPOLOGY_ID, DEVICE1_ID, 'EP100'), - endpoint_id(TOPOLOGY_ID, DEVICE2_ID, 'EP100'), + "service_id": deepcopy(SERVICE_DEV1_DEV2_ID), + "service_type": ServiceTypeEnum.SERVICETYPE_L3NM, + "service_endpoint_ids": [ + endpoint_id(TOPOLOGY_ID, DEVICE1_ID, "EP100"), + endpoint_id(TOPOLOGY_ID, DEVICE2_ID, "EP100"), ], # 'service_constraints': [ # {'constraint_type': 'latency_ms', 'constraint_value': '15.2'}, # {'constraint_type': 'jitter_us', 'constraint_value': '1.2'}, # ], - 'service_status': {'service_status': ServiceStatusEnum.SERVICESTATUS_ACTIVE}, - 'service_config': {'config_rules': [ - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc1/value', 'value7'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc2/value', 'value8'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc3/value', 'value9'), - ]}, + "service_status": {"service_status": ServiceStatusEnum.SERVICESTATUS_ACTIVE}, + "service_config": { + "config_rules": [ + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc1/value", "value7"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc2/value", "value8"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc3/value", "value9"), + ] + }, } -SERVICE_DEV1_DEV3_UUID = 'SVC:DEV1/EP100-DEV3/EP100' +SERVICE_DEV1_DEV3_UUID = "SVC:DEV1/EP100-DEV3/EP100" SERVICE_DEV1_DEV3_ID = { - 'context_id': deepcopy(CONTEXT_ID), - 'service_uuid': {'uuid': SERVICE_DEV1_DEV3_UUID}, + "context_id": deepcopy(CONTEXT_ID), + "service_uuid": {"uuid": SERVICE_DEV1_DEV3_UUID}, } SERVICE_DEV1_DEV3 = { - 'service_id': deepcopy(SERVICE_DEV1_DEV3_ID), - 'service_type': ServiceTypeEnum.SERVICETYPE_L3NM, - 'service_endpoint_ids' : [ - endpoint_id(TOPOLOGY_ID, DEVICE1_ID, 'EP100'), - endpoint_id(TOPOLOGY_ID, DEVICE3_ID, 'EP100'), + "service_id": deepcopy(SERVICE_DEV1_DEV3_ID), + "service_type": ServiceTypeEnum.SERVICETYPE_L3NM, + "service_endpoint_ids": [ + endpoint_id(TOPOLOGY_ID, DEVICE1_ID, "EP100"), + endpoint_id(TOPOLOGY_ID, DEVICE3_ID, "EP100"), ], # 'service_constraints': [ # {'constraint_type': 'latency_ms', 'constraint_value': '5.8'}, # {'constraint_type': 'jitter_us', 'constraint_value': '0.1'}, # ], - 'service_status': {'service_status': ServiceStatusEnum.SERVICESTATUS_ACTIVE}, - 'service_config': {'config_rules': [ - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc1/value', 'value7'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc2/value', 'value8'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc3/value', 'value9'), - ]}, + "service_status": {"service_status": ServiceStatusEnum.SERVICESTATUS_ACTIVE}, + "service_config": { + "config_rules": [ + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc1/value", "value7"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc2/value", "value8"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc3/value", "value9"), + ] + }, } -SERVICE_DEV2_DEV3_UUID = 'SVC:DEV2/EP100-DEV3/EP100' +SERVICE_DEV2_DEV3_UUID = "SVC:DEV2/EP100-DEV3/EP100" SERVICE_DEV2_DEV3_ID = { - 'context_id': deepcopy(CONTEXT_ID), - 'service_uuid': {'uuid': SERVICE_DEV2_DEV3_UUID}, + "context_id": deepcopy(CONTEXT_ID), + "service_uuid": {"uuid": SERVICE_DEV2_DEV3_UUID}, } SERVICE_DEV2_DEV3 = { - 'service_id': deepcopy(SERVICE_DEV2_DEV3_ID), - 'service_type': ServiceTypeEnum.SERVICETYPE_L3NM, - 'service_endpoint_ids' : [ - endpoint_id(TOPOLOGY_ID, DEVICE2_ID, 'EP100'), - endpoint_id(TOPOLOGY_ID, DEVICE3_ID, 'EP100'), + "service_id": deepcopy(SERVICE_DEV2_DEV3_ID), + "service_type": ServiceTypeEnum.SERVICETYPE_L3NM, + "service_endpoint_ids": [ + endpoint_id(TOPOLOGY_ID, DEVICE2_ID, "EP100"), + endpoint_id(TOPOLOGY_ID, DEVICE3_ID, "EP100"), ], # 'service_constraints': [ # {'constraint_type': 'latency_ms', 'constraint_value': '23.1'}, # {'constraint_type': 'jitter_us', 'constraint_value': '3.4'}, # ], - 'service_status': {'service_status': ServiceStatusEnum.SERVICESTATUS_ACTIVE}, - 'service_config': {'config_rules': [ - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc1/value', 'value7'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc2/value', 'value8'), - config_rule(ConfigActionEnum.CONFIGACTION_SET, 'svc/rsrc3/value', 'value9'), - ]}, + "service_status": {"service_status": ServiceStatusEnum.SERVICESTATUS_ACTIVE}, + "service_config": { + "config_rules": [ + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc1/value", "value7"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc2/value", "value8"), + config_rule(ConfigActionEnum.CONFIGACTION_SET, "svc/rsrc3/value", "value9"), + ] + }, } diff --git a/src/opticalattackdetector/tests/test_unitary.py b/src/opticalattackdetector/tests/test_unitary.py index 5aadbe9b177f68dcceca32fe3c4ac57c4fe00163..1c55563d8452956b008798dfa956ec130e50b0b0 100644 --- a/src/opticalattackdetector/tests/test_unitary.py +++ b/src/opticalattackdetector/tests/test_unitary.py @@ -12,27 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging, pytest +import logging from unittest.mock import patch -from common.Constants import DEFAULT_GRPC_BIND_ADDRESS, DEFAULT_GRPC_MAX_WORKERS, DEFAULT_GRPC_GRACE_PERIOD -from common.proto.context_pb2 import ContextIdList, ContextId, Empty, Service, ContextId, ServiceList +import pytest +from common.Constants import ( + DEFAULT_GRPC_BIND_ADDRESS, + DEFAULT_GRPC_GRACE_PERIOD, + DEFAULT_GRPC_MAX_WORKERS, +) +from common.proto.context_pb2 import ( + ContextId, + ContextIdList, + Empty, + Service, + ServiceList, +) from common.proto.monitoring_pb2 import Kpi, KpiList - +from opticalattackdetector.client.OpticalAttackDetectorClient import ( + OpticalAttackDetectorClient, +) from opticalattackdetector.Config import GRPC_SERVICE_PORT -from opticalattackdetector.client.OpticalAttackDetectorClient import OpticalAttackDetectorClient -from opticalattackdetector.service.OpticalAttackDetectorService import OpticalAttackDetectorService +from opticalattackdetector.service.OpticalAttackDetectorService import ( + OpticalAttackDetectorService, +) + from .example_objects import CONTEXT_ID, CONTEXT_ID_2, SERVICE_DEV1_DEV2 -port = 10000 + GRPC_SERVICE_PORT # avoid privileged ports +port = 10000 + GRPC_SERVICE_PORT # avoid privileged ports LOGGER = logging.getLogger(__name__) LOGGER.setLevel(logging.DEBUG) -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def optical_attack_detector_service(): _service = OpticalAttackDetectorService( - port=port, max_workers=DEFAULT_GRPC_MAX_WORKERS, grace_period=DEFAULT_GRPC_GRACE_PERIOD) + port=port, + max_workers=DEFAULT_GRPC_MAX_WORKERS, + grace_period=DEFAULT_GRPC_GRACE_PERIOD, + ) # mocker_context_client = mock.patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client') # mocker_context_client.start() @@ -45,28 +64,44 @@ def optical_attack_detector_service(): # mocker_context_client.stop() # mocker_influx_db.stop() -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def optical_attack_detector_client(optical_attack_detector_service): - _client = OpticalAttackDetectorClient(address='127.0.0.1', port=port) + _client = OpticalAttackDetectorClient(address="127.0.0.1", port=port) yield _client _client.close() -def test_notify_service_update(optical_attack_detector_client: OpticalAttackDetectorClient): + +def test_notify_service_update( + optical_attack_detector_client: OpticalAttackDetectorClient, +): service = Service() optical_attack_detector_client.NotifyServiceUpdate(service) -def test_detect_attack_no_contexts(optical_attack_detector_client: OpticalAttackDetectorClient): - with patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client') as context, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client') as influxdb: + +def test_detect_attack_no_contexts( + optical_attack_detector_client: OpticalAttackDetectorClient, +): + with patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client" + ) as context, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client" + ) as influxdb: request = Empty() optical_attack_detector_client.DetectAttack(request) context.ListContextIds.assert_called_once() influxdb.query.assert_called_once() context.ListServices.assert_not_called() -def test_detect_attack_with_context(optical_attack_detector_client: OpticalAttackDetectorClient,): - with patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client') as context, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client') as influxdb: + +def test_detect_attack_with_context( + optical_attack_detector_client: OpticalAttackDetectorClient, +): + with patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client" + ) as context, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client" + ) as influxdb: # setting up the mock cid_list = ContextIdList() cid_list.context_ids.append(ContextId(**CONTEXT_ID)) @@ -81,9 +116,15 @@ def test_detect_attack_with_context(optical_attack_detector_client: OpticalAttac context.ListServices.assert_called_with(cid_list.context_ids[0]) influxdb.query.assert_called_once() -def test_detect_attack_with_contexts(optical_attack_detector_client: OpticalAttackDetectorClient,): - with patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client') as context, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client') as influxdb: + +def test_detect_attack_with_contexts( + optical_attack_detector_client: OpticalAttackDetectorClient, +): + with patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client" + ) as context, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client" + ) as influxdb: # setting up the mock cid_list = ContextIdList() cid_list.context_ids.append(ContextId(**CONTEXT_ID)) @@ -100,10 +141,17 @@ def test_detect_attack_with_contexts(optical_attack_detector_client: OpticalAtta context.ListServices.assert_any_call(cid_list.context_ids[1]) influxdb.query.assert_called_once() -def test_detect_attack_with_service(optical_attack_detector_client: OpticalAttackDetectorClient,): - with patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client') as context, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client') as influxdb, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.dbscanserving_client') as dbscan: + +def test_detect_attack_with_service( + optical_attack_detector_client: OpticalAttackDetectorClient, +): + with patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client" + ) as context, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client" + ) as influxdb, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.dbscanserving_client" + ) as dbscan: # setting up the mock cid_list = ContextIdList() @@ -126,11 +174,19 @@ def test_detect_attack_with_service(optical_attack_detector_client: OpticalAttac influxdb.query.assert_called_once() dbscan.Detect.assert_called() -def test_detect_attack_no_attack(optical_attack_detector_client: OpticalAttackDetectorClient,): - with patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client') as context, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client') as influxdb, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.dbscanserving_client') as dbscan, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.attack_mitigator_client') as mitigator: + +def test_detect_attack_no_attack( + optical_attack_detector_client: OpticalAttackDetectorClient, +): + with patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client" + ) as context, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client" + ) as influxdb, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.dbscanserving_client" + ) as dbscan, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.attack_mitigator_client" + ) as mitigator: # setting up the mock cid_list = ContextIdList() @@ -155,11 +211,19 @@ def test_detect_attack_no_attack(optical_attack_detector_client: OpticalAttackDe dbscan.Detect.assert_called() mitigator.NotifyAttack.assert_not_called() -def test_detect_attack_with_attack(optical_attack_detector_client: OpticalAttackDetectorClient,): - with patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client') as context, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client') as influxdb, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.dbscanserving_client') as dbscan, \ - patch('opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.attack_mitigator_client') as mitigator: + +def test_detect_attack_with_attack( + optical_attack_detector_client: OpticalAttackDetectorClient, +): + with patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.context_client" + ) as context, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.influxdb_client" + ) as influxdb, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.dbscanserving_client" + ) as dbscan, patch( + "opticalattackdetector.service.OpticalAttackDetectorServiceServicerImpl.attack_mitigator_client" + ) as mitigator: # setting up the mock cid_list = ContextIdList() @@ -184,10 +248,14 @@ def test_detect_attack_with_attack(optical_attack_detector_client: OpticalAttack dbscan.Detect.assert_called() mitigator.NotifyAttack.assert_called() -def test_report_summarized_kpi(optical_attack_detector_client: OpticalAttackDetectorClient): + +def test_report_summarized_kpi( + optical_attack_detector_client: OpticalAttackDetectorClient, +): kpi_list = KpiList() optical_attack_detector_client.ReportSummarizedKpi(kpi_list) + def test_report_kpi(optical_attack_detector_client: OpticalAttackDetectorClient): kpi_list = KpiList() optical_attack_detector_client.ReportKpi(kpi_list) diff --git a/src/opticalattackmanager/Dockerfile b/src/opticalattackmanager/Dockerfile index 908c4893f5178fdb01bca48d54f7641a91ca08c6..a6aedb8049d46905ba0dbf8792dd51c58e1940bf 100644 --- a/src/opticalattackmanager/Dockerfile +++ b/src/opticalattackmanager/Dockerfile @@ -52,19 +52,30 @@ RUN python3 -m pip install --upgrade pip-tools # Note: this step enables sharing the previous Docker build steps among all the Python components WORKDIR /home/opticalattackmanager/teraflow COPY --chown=opticalattackmanager:opticalattackmanager common_requirements.in common_requirements.in -RUN pip-compile --quiet --output-file=common_requirements.txt common_requirements.in +COPY --chown=opticalattackmanager:opticalattackmanager src/opticalattackmanager/requirements.in opticalattackmanager/requirements.in +RUN sed -i '/protobuf/d' common_requirements.in && sed -i '/grpc/d' common_requirements.in +RUN pip-compile --output-file=common_requirements.txt common_requirements.in opticalattackmanager/requirements.in RUN python3 -m pip install -r common_requirements.txt +# Get Python packages per module +# COPY --chown=opticalattackmanager:opticalattackmanager src/opticalattackmanager/requirements.in opticalattackmanager/requirements.in +# RUN pip-compile --quiet --output-file=opticalattackmanager/requirements.txt opticalattackmanager/requirements.in +# RUN python3 -m pip install -r opticalattackmanager/requirements.txt + # Add common files into working directory WORKDIR /home/opticalattackmanager/teraflow/common COPY --chown=opticalattackmanager:opticalattackmanager src/common/. ./ # Create proto sub-folder, copy .proto files, and generate Python code RUN mkdir -p /home/opticalattackmanager/teraflow/common/proto +RUN mkdir -p /home/opticalattackmanager/teraflow/common/proto/asyncio WORKDIR /home/opticalattackmanager/teraflow/common/proto RUN touch __init__.py +RUN touch asyncio/__init__.py COPY --chown=opticalattackmanager:opticalattackmanager proto/*.proto ./ RUN python3 -m grpc_tools.protoc -I=. --python_out=. --grpc_python_out=. *.proto +# new line added to generate protobuf for the `grpclib` library +RUN python3 -m grpc_tools.protoc -I=./ --python_out=./asyncio --grpclib_python_out=./asyncio *.proto RUN rm *.proto RUN find . -type f -exec sed -i -E 's/(import\ .*)_pb2/from . \1_pb2/g' {} \; @@ -72,17 +83,10 @@ RUN find . -type f -exec sed -i -E 's/(import\ .*)_pb2/from . \1_pb2/g' {} \; RUN mkdir -p /home/opticalattackmanager/teraflow/opticalattackmanager WORKDIR /home/opticalattackmanager/teraflow -# Get Python packages per module -COPY --chown=opticalattackmanager:opticalattackmanager src/opticalattackmanager/requirements.in opticalattackmanager/requirements.in -RUN pip-compile --quiet --output-file=opticalattackmanager/requirements.txt opticalattackmanager/requirements.in -RUN python3 -m pip install -r opticalattackmanager/requirements.txt - # Add files into working directory COPY --chown=opticalattackmanager:opticalattackmanager src/context/. context COPY --chown=opticalattackmanager:opticalattackmanager src/monitoring/. monitoring -COPY --chown=opticalattackmanager:opticalattackmanager src/service/. service -COPY --chown=opticalattackmanager:opticalattackmanager src/dbscanserving/. dbscanserving -COPY --chown=opticalattackmanager:opticalattackmanager src/opticalattackmitigator/. opticalattackmitigator +COPY --chown=opticalattackmanager:opticalattackmanager src/opticalattackdetector/. opticalattackdetector COPY --chown=opticalattackmanager:opticalattackmanager src/opticalattackmanager/. opticalattackmanager # Start opticalattackmanager service diff --git a/src/opticalattackmanager/requirements.in b/src/opticalattackmanager/requirements.in index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..39d4f655900f43a9bd1d8196eececc85119c6f1e 100644 --- a/src/opticalattackmanager/requirements.in +++ b/src/opticalattackmanager/requirements.in @@ -0,0 +1,4 @@ +grpcio==1.49.* +grpcio-health-checking==1.49.* +grpcio-tools==1.49.* +grpclib[protobuf] \ No newline at end of file diff --git a/src/opticalattackmanager/send_task.py b/src/opticalattackmanager/send_task.py index 2b41dcc4f85053e2fe55238c660a1a368efd6814..101f82e67ce5237b584c6af301b64933b5835618 100644 --- a/src/opticalattackmanager/send_task.py +++ b/src/opticalattackmanager/send_task.py @@ -1,23 +1,27 @@ import asyncio import logging +import random import grpc -import random -from common.Settings import get_log_level, get_setting from common.proto.dbscanserving_pb2 import DetectionRequest, DetectionResponse, Sample from common.proto.dbscanserving_pb2_grpc import DetectorStub +from common.Settings import get_log_level, get_setting # For more channel options, please see https://grpc.io/grpc/core/group__grpc__arg__keys.html -CHANNEL_OPTIONS = [('grpc.lb_policy_name', 'pick_first'), - ('grpc.enable_retries', True), - ('grpc.keepalive_timeout_ms', 10000)] +CHANNEL_OPTIONS = [ + ("grpc.lb_policy_name", "pick_first"), + ("grpc.enable_retries", True), + ("grpc.keepalive_timeout_ms", 10000), +] # based on https://github.com/grpc/grpc/blob/master/examples/python/helloworld/async_greeter_client_with_options.py + async def run(endpoint, service_id) -> None: - - async with grpc.aio.insecure_channel(target=endpoint, - options=CHANNEL_OPTIONS) as channel: + + async with grpc.aio.insecure_channel( + target=endpoint, options=CHANNEL_OPTIONS + ) as channel: stub = DetectorStub(channel) # generate data @@ -31,40 +35,36 @@ async def run(endpoint, service_id) -> None: for _ in range(200): grpc_sample = Sample() for __ in range(20): - grpc_sample.features.append(random.uniform(0., 10.)) + grpc_sample.features.append(random.uniform(0.0, 10.0)) request.samples.append(grpc_sample) for _ in range(100): grpc_sample = Sample() for __ in range(20): - grpc_sample.features.append(random.uniform(50., 60.)) + grpc_sample.features.append(random.uniform(50.0, 60.0)) request.samples.append(grpc_sample) for _ in range(alien_samples): grpc_sample = Sample() for __ in range(20): - grpc_sample.features.append(random.uniform(5000., 6000.)) + grpc_sample.features.append(random.uniform(5000.0, 6000.0)) request.samples.append(grpc_sample) # Timeout in seconds. # Please refer gRPC Python documents for more detail. https://grpc.io/grpc/python/grpc.html - response: DetectionResponse = await stub.Detect(request, - timeout=10) + response: DetectionResponse = await stub.Detect(request, timeout=10) print("Greeter client received:", service_id) return service_id * 2 async def main_loop(): - host = get_setting('DBSCANSERVINGSERVICE_SERVICE_HOST') - port = get_setting('DBSCANSERVINGSERVICE_SERVICE_PORT_GRPC') - endpoint = '{:s}:{:s}'.format(str(host), str(port)) + host = get_setting("DBSCANSERVINGSERVICE_SERVICE_HOST") + port = get_setting("DBSCANSERVINGSERVICE_SERVICE_PORT_GRPC") + endpoint = "{:s}:{:s}".format(str(host), str(port)) - ret = await asyncio.gather( - run(endpoint, 1), - run(endpoint, 2) - ) + ret = await asyncio.gather(run(endpoint, 1), run(endpoint, 2)) print(ret) -if __name__ == '__main__': - +if __name__ == "__main__": + logging.basicConfig() - asyncio.run(main_loop()) \ No newline at end of file + asyncio.run(main_loop()) diff --git a/src/opticalattackmanager/service/__main__.py b/src/opticalattackmanager/service/__main__.py index 63edd58c2923aa6fba9060a5c4ced867423deeb2..589dd67b6edd6297c62cea80061fae0a7a546597 100644 --- a/src/opticalattackmanager/service/__main__.py +++ b/src/opticalattackmanager/service/__main__.py @@ -1,155 +1,264 @@ -import asyncio, grpc, random -from common.proto.optical_attack_detector_pb2_grpc import OpticalAttackDetectorServiceStub -import logging, signal, sys, time, threading +import asyncio +import logging +import random +import signal +import sys +import threading +import time from multiprocessing import Manager, Process from typing import List -from prometheus_client import start_http_server -from common.Settings import get_log_level, get_metrics_port, get_setting -from common.proto.context_pb2 import ContextIdList, Empty, EventTypeEnum, ServiceId, ServiceIdList -from context.client.ContextClient import ContextClient -from opticalattackmanager.Config import MONITORING_INTERVAL -from common.proto.monitoring_pb2 import KpiDescriptor +from grpclib.client import Channel + +from common.Constants import ServiceNameEnum +from common.proto.context_pb2 import ( + ContextIdList, + Empty, + EventTypeEnum, + ServiceIdList, +) from common.proto.kpi_sample_types_pb2 import KpiSampleType +from common.proto.monitoring_pb2 import KpiDescriptor +from common.proto.asyncio.optical_attack_detector_grpc import ( + OpticalAttackDetectorServiceStub, +) +from common.proto.asyncio.context_pb2 import ( + ServiceId, +) +from common.Settings import ( + ENVVAR_SUFIX_SERVICE_HOST, + ENVVAR_SUFIX_SERVICE_PORT_GRPC, + get_env_var_name, + get_log_level, + get_metrics_port, + get_setting, + wait_for_environment_variables, +) +from context.client.ContextClient import ContextClient from monitoring.client.MonitoringClient import MonitoringClient +from opticalattackmanager.Config import MONITORING_INTERVAL +from opticalattackmanager.utils.EventsCollector import EventsCollector +from prometheus_client import start_http_server terminate = threading.Event() LOGGER = None -# For more channel options, please see https://grpc.io/grpc/core/group__grpc__arg__keys.html -CHANNEL_OPTIONS = [('grpc.lb_policy_name', 'pick_first'), - ('grpc.enable_retries', True), - ('grpc.keepalive_timeout_ms', 10000)] +# For more channel options, please see: +# https://grpc.io/grpc/core/group__grpc__arg__keys.html +# CHANNEL_OPTIONS = [ +# ("grpc.lb_policy_name", "pick_first"), +# ("grpc.enable_retries", True), +# ("grpc.keepalive_timeout_ms", 10000), +# ] # TODO: configure retries -def signal_handler(signal, frame): # pylint: disable=redefined-outer-name - LOGGER.warning('Terminate signal received') + +def signal_handler(signal, frame): # pylint: disable=redefined-outer-name + LOGGER.warning("Terminate signal received") terminate.set() -async def detect_attack(endpoint, context_id, service_id): +async def detect_attack(host: str, port: int, context_id: str, service_id: str) -> None: try: - async with grpc.aio.insecure_channel(target=endpoint, - options=CHANNEL_OPTIONS) as channel: + LOGGER.info("Sending request for {}...".format(service_id)) + async with Channel(host, port) as channel: + # async with grpc.aio.insecure_channel( + # target=endpoint, options=CHANNEL_OPTIONS + # ) as channel: stub = OpticalAttackDetectorServiceStub(channel) service = ServiceId() service.context_id.context_uuid.uuid = context_id service.service_uuid.uuid = str(service_id) # Timeout in seconds. - # Please refer gRPC Python documents for more detail. https://grpc.io/grpc/python/grpc.html - await stub.DetectAttack(service, - timeout=10) + # Please refer gRPC Python documents for more detail. + # https://grpc.io/grpc/python/grpc.html + await stub.DetectAttack(service, timeout=10) LOGGER.info("Monitoring finished for {}".format(service_id)) except Exception as e: + LOGGER.warning("Exception while processing service_id {}".format(service_id)) LOGGER.exception(e) async def monitor_services(service_list: List[ServiceId]): - monitoring_interval = int(get_setting('MONITORING_INTERVAL', default=MONITORING_INTERVAL)) + monitoring_interval = int( + get_setting("MONITORING_INTERVAL", default=MONITORING_INTERVAL) + ) - host = get_setting('OPTICALATTACKDETECTORSERVICE_SERVICE_HOST') - port = get_setting('OPTICALATTACKDETECTORSERVICE_SERVICE_PORT_GRPC') - endpoint = '{:s}:{:s}'.format(str(host), str(port)) + host = get_setting("OPTICALATTACKDETECTORSERVICE_SERVICE_HOST") + port = int(get_setting("OPTICALATTACKDETECTORSERVICE_SERVICE_PORT_GRPC")) - LOGGER.info('Starting execution of the async loop') + LOGGER.info("Starting execution of the async loop") while not terminate.is_set(): - LOGGER.info('Starting new monitoring cycle...') + if len(service_list) == 0: + LOGGER.debug("No services to monitor...") + time.sleep(monitoring_interval) + continue + + LOGGER.info("Starting new monitoring cycle...") start_time = time.time() tasks = [] for service in service_list: - aw = detect_attack(endpoint, service['context'], service['service']) + aw = detect_attack(host, port, service["context"], service["service"]) tasks.append(aw) [await aw for aw in tasks] - - end_time = time.time() - - diff = end_time - start_time - LOGGER.info('Monitoring loop with {} services took {:.3f} seconds ({:.2f}%)... Waiting for {:.2f} seconds...'.format(len(service_list), diff, (diff / monitoring_interval) * 100, monitoring_interval - diff)) - if diff / monitoring_interval > 0.9: - LOGGER.warning('Monitoring loop is taking {} % of the desired time ({} seconds)'.format((diff / monitoring_interval) * 100, monitoring_interval)) + end_time = time.time() - time.sleep(monitoring_interval - diff) + time_taken = end_time - start_time + LOGGER.info( + "Monitoring loop with {} services took {:.3f} seconds ({:.2f}%)... " + "Waiting for {:.2f} seconds...".format( + len(service_list), + time_taken, + (time_taken / monitoring_interval) * 100, + monitoring_interval - time_taken, + ) + ) + + if time_taken / monitoring_interval > 0.9: + LOGGER.warning( + "Monitoring loop is taking {} % of the desired time " + "({} seconds)".format( + (time_taken / monitoring_interval) * 100, monitoring_interval + ) + ) + if monitoring_interval - time_taken > 0: + time.sleep(monitoring_interval - time_taken) def create_kpi(client: MonitoringClient, service_id): # create kpi kpi_description: KpiDescriptor = KpiDescriptor() - kpi_description.kpi_description = 'Security status of service {}'.format(service_id) + kpi_description.kpi_description = "Security status of service {}".format(service_id) kpi_description.service_id.service_uuid.uuid = service_id kpi_description.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN new_kpi = client.SetKpi(kpi_description) - LOGGER.info('Created KPI {}...'.format(new_kpi.kpi_id)) + LOGGER.info("Created KPI {}...".format(new_kpi.kpi_id)) return new_kpi def get_context_updates(service_list: List[ServiceId]): # to make sure we are thread safe... - LOGGER.info('Connecting with context and monitoring components...') + LOGGER.info("Connecting with context and monitoring components...") context_client: ContextClient = ContextClient() monitoring_client: MonitoringClient = MonitoringClient() - LOGGER.info('Connected successfully... Waiting for events...') + + events_collector: EventsCollector = EventsCollector(context_client) + events_collector.start() + + LOGGER.info("Connected successfully... Waiting for events...") for service in service_list: - kpi_id = create_kpi(monitoring_client, service['service']) + kpi_id = create_kpi(monitoring_client, service["service"]) time.sleep(20) - for event in context_client.GetServiceEvents(Empty()): - LOGGER.info('Event received: {}'.format(event)) + while not terminate.wait(timeout=1): + event = events_collector.get_event(block=True, timeout=1) + if event is None: + LOGGER.info("No event received") + continue # no event received + LOGGER.info("Event received: {}".format(event)) if event.event.event_type == EventTypeEnum.EVENTTYPE_CREATE: - LOGGER.info('Service created: {}'.format(event.service_id)) + LOGGER.info("Service created: {}".format(event.service_id)) kpi_id = create_kpi(monitoring_client, event.service_id.service_uuid.uuid) - service_list.append({'context': event.service_id.context_id.context_uuid.uuid, 'service': event.service_id.service_uuid.uuid, 'kpi': kpi_id.kpi_id.uuid}) - + service_list.append( + { + "context": event.service_id.context_id.context_uuid.uuid, + "service": event.service_id.service_uuid.uuid, + "kpi": kpi_id.kpi_id.uuid, + } + ) + elif event.event.event_type == EventTypeEnum.EVENTTYPE_REMOVE: - LOGGER.info('Service removed: {}'.format(event.service_id)) + LOGGER.info("Service removed: {}".format(event.service_id)) # find service and remove it from the list of currently monitored for service in service_list: - if service['service'] == event.service_id.service_uuid.uuid and service['context'] == event.service_id.context_id.context_uuid.uuid: + if ( + service["service"] == event.service_id.service_uuid.uuid + and service["context"] + == event.service_id.context_id.context_uuid.uuid + ): service_list.remove(service) break - # service_list.remove({'context': event.service_id.context_id.context_uuid.uuid, 'service': event.service_id.service_uuid.uuid}) - - if terminate.is_set(): # if terminate is set - LOGGER.warning('Stopping execution of the get_context_updates...') - context_client.close() - monitoring_client.close() - break # break the while and stop execution - LOGGER.debug('Waiting for next event...') + + events_collector.stop() + + # for event in context_client.GetServiceEvents(Empty()): + # LOGGER.info("Event received: {}".format(event)) + # if event.event.event_type == EventTypeEnum.EVENTTYPE_CREATE: + # LOGGER.info("Service created: {}".format(event.service_id)) + # kpi_id = create_kpi(monitoring_client, event.service_id.service_uuid.uuid) + # service_list.append( + # { + # "context": event.service_id.context_id.context_uuid.uuid, + # "service": event.service_id.service_uuid.uuid, + # "kpi": kpi_id.kpi_id.uuid, + # } + # ) + + # elif event.event.event_type == EventTypeEnum.EVENTTYPE_REMOVE: + # LOGGER.info("Service removed: {}".format(event.service_id)) + # # find service and remove it from the list of currently monitored + # for service in service_list: + # if ( + # service["service"] == event.service_id.service_uuid.uuid + # and service["context"] + # == event.service_id.context_id.context_uuid.uuid + # ): + # service_list.remove(service) + # break + + # if terminate.is_set(): # if terminate is set + # LOGGER.warning("Stopping execution of the get_context_updates...") + # context_client.close() + # monitoring_client.close() + # break # break the while and stop execution + # LOGGER.debug("Waiting for next event...") def main(): - global LOGGER # pylint: disable=global-statement + global LOGGER # pylint: disable=global-statement log_level = get_log_level() logging.basicConfig(level=log_level) LOGGER = logging.getLogger(__name__) - signal.signal(signal.SIGINT, signal_handler) + wait_for_environment_variables( + [ + get_env_var_name( + ServiceNameEnum.OPTICALATTACKDETECTOR, ENVVAR_SUFIX_SERVICE_HOST + ), + get_env_var_name( + ServiceNameEnum.OPTICALATTACKDETECTOR, ENVVAR_SUFIX_SERVICE_PORT_GRPC + ), + ] + ) + + signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - LOGGER.info('Starting...') + LOGGER.info("Starting...") # Start metrics server metrics_port = get_metrics_port() # start_http_server(metrics_port) # TODO: uncomment this line - LOGGER.info('Connecting with context component...') + LOGGER.info("Connecting with context component...") context_client: ContextClient = ContextClient() - context_client.connect() - LOGGER.info('Connected successfully...') + monitoring_client: MonitoringClient = MonitoringClient() + LOGGER.info("Connected successfully...") # creating a thread-safe list to be shared among threads service_list = Manager().list() - service_list.append({'context': 'admin', 'service': '1213'}) - service_list.append({'context': 'admin', 'service': '1456'}) + service_list.append({"context": "admin", "service": "1213"}) + service_list.append({"context": "admin", "service": "1456"}) context_ids: ContextIdList = context_client.ListContextIds(Empty()) @@ -157,29 +266,40 @@ def main(): for context_id in context_ids.context_ids: context_services: ServiceIdList = context_client.ListServiceIds(context_id) for service in context_services.service_ids: - kpi_id = create_kpi(service.service_uuid.uuid) - service_list.append({'context': context_id.context_uuid.uuid, 'service': service.service_uuid.uuid, 'kpi': kpi_id}) - + kpi_id = create_kpi(monitoring_client, service.service_uuid.uuid) + service_list.append( + { + "context": context_id.context_uuid.uuid, + "service": service.service_uuid.uuid, + "kpi": kpi_id.kpi_id.uuid, + } + ) + context_client.close() + monitoring_client.close() # starting background process to monitor service addition/removal process_context = Process(target=get_context_updates, args=(service_list,)) process_context.start() + time.sleep(5) # wait for the context updates to startup + # runs the async loop in the background loop = asyncio.get_event_loop() loop.run_until_complete(monitor_services(service_list)) # asyncio.create_task(monitor_services(service_list)) # Wait for Ctrl+C or termination signal - while not terminate.wait(timeout=0.1): pass + while not terminate.wait(timeout=0.1): + pass - LOGGER.info('Terminating...') + LOGGER.info("Terminating...") process_context.kill() # process_security_loop.kill() - LOGGER.info('Bye') + LOGGER.info("Bye") return 0 -if __name__ == '__main__': + +if __name__ == "__main__": sys.exit(main()) diff --git a/src/opticalattackmanager/utils/EventsCollector.py b/src/opticalattackmanager/utils/EventsCollector.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d8ffcad9a719d2ce2a95a86339952356dd3892 --- /dev/null +++ b/src/opticalattackmanager/utils/EventsCollector.py @@ -0,0 +1,83 @@ +# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import queue +import threading + +import grpc +from common.proto.context_pb2 import Empty +from common.tools.grpc.Tools import grpc_message_to_json_string +from context.client.ContextClient import ContextClient + +LOGGER = logging.getLogger(__name__) +LOGGER.setLevel(logging.DEBUG) + + +class EventsCollector: + def __init__( + self, context_client_grpc: ContextClient, log_events_received=False + ) -> None: + self._events_queue = queue.Queue() + self._log_events_received = log_events_received + + self._service_stream = context_client_grpc.GetServiceEvents(Empty()) + + self._service_thread = threading.Thread( + target=self._collect, args=(self._service_stream,), daemon=False + ) + + def _collect(self, events_stream) -> None: + try: + for event in events_stream: + if self._log_events_received: + LOGGER.info( + "[_collect] event: {:s}".format( + grpc_message_to_json_string(event) + ) + ) + self._events_queue.put_nowait(event) + except grpc.RpcError as e: + if e.code() != grpc.StatusCode.CANCELLED: # pylint: disable=no-member + raise # pragma: no cover + + def start(self): + self._service_thread.start() + + def get_event(self, block: bool = True, timeout: float = 0.1): + try: + return self._events_queue.get(block=block, timeout=timeout) + except queue.Empty: # pylint: disable=catching-non-exception + return None + + def get_events(self, block: bool = True, timeout: float = 0.1, count: int = None): + events = [] + if count is None: + while True: + event = self.get_event(block=block, timeout=timeout) + if event is None: + break + events.append(event) + else: + for _ in range(count): + event = self.get_event(block=block, timeout=timeout) + if event is None: + continue + events.append(event) + return sorted(events, key=lambda e: e.event.timestamp.timestamp) + + def stop(self): + self._service_stream.cancel() + + self._service_thread.join() diff --git a/src/opticalattackmanager/utils/__init__.py b/src/opticalattackmanager/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/opticalattackmitigator/Config.py b/src/opticalattackmitigator/Config.py index 130381d8bd1db49803aefa992435808bed3a87d3..8e01631bbcd95e6d29bff75449b6e06bd71828ce 100644 --- a/src/opticalattackmitigator/Config.py +++ b/src/opticalattackmitigator/Config.py @@ -19,8 +19,6 @@ LOG_LEVEL = logging.DEBUG # gRPC settings GRPC_SERVICE_PORT = 10007 -GRPC_MAX_WORKERS = 10 -GRPC_GRACE_PERIOD = 60 # Prometheus settings METRICS_PORT = 9192 diff --git a/src/opticalattackmitigator/__init__.py b/src/opticalattackmitigator/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/opticalattackmitigator/__init__.py +++ b/src/opticalattackmitigator/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/opticalattackmitigator/client/OpticalAttackMitigatorClient.py b/src/opticalattackmitigator/client/OpticalAttackMitigatorClient.py index ad2ff7928af0564a7f8cd2c79a491807a57e7a74..f15b7e2c1905bfb6ae34898dc8c6ed8cfd1629c5 100644 --- a/src/opticalattackmitigator/client/OpticalAttackMitigatorClient.py +++ b/src/opticalattackmitigator/client/OpticalAttackMitigatorClient.py @@ -12,40 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -import grpc, logging -from common.Settings import get_setting -from common.tools.client.RetryDecorator import retry, delay_exponential +import logging + +import grpc from common.proto.optical_attack_mitigator_pb2 import AttackDescription, AttackResponse from common.proto.optical_attack_mitigator_pb2_grpc import AttackMitigatorStub +from common.Settings import get_setting +from common.tools.client.RetryDecorator import delay_exponential, retry +from common.tools.grpc.Tools import grpc_message_to_json LOGGER = logging.getLogger(__name__) MAX_RETRIES = 15 DELAY_FUNCTION = delay_exponential(initial=0.01, increment=2.0, maximum=5.0) -RETRY_DECORATOR = retry(max_retries=MAX_RETRIES, delay_function=DELAY_FUNCTION, prepare_method_name='connect') +RETRY_DECORATOR = retry( + max_retries=MAX_RETRIES, + delay_function=DELAY_FUNCTION, + prepare_method_name="connect", +) + class OpticalAttackMitigatorClient: def __init__(self, host=None, port=None): - if not host: host = get_setting('OPTICALATTACKMITIGATORSERVICE_SERVICE_HOST', default="DBSCANSERVING") - if not port: port = get_setting('OPTICALATTACKMITIGATORSERVICE_SERVICE_PORT_GRPC', default=10007) - self.endpoint = '{:s}:{:s}'.format(str(host), str(port)) - LOGGER.debug('Creating channel to {:s}...'.format(str(self.endpoint))) + if not host: + host = get_setting( + "OPTICALATTACKMITIGATORSERVICE_SERVICE_HOST", default="DBSCANSERVING" + ) + if not port: + port = get_setting( + "OPTICALATTACKMITIGATORSERVICE_SERVICE_PORT_GRPC", default=10007 + ) + self.endpoint = "{:s}:{:s}".format(str(host), str(port)) + LOGGER.debug("Creating channel to {:s}...".format(str(self.endpoint))) self.channel = None self.stub = None self.connect() - LOGGER.debug('Channel created') + LOGGER.debug("Channel created") def connect(self): self.channel = grpc.insecure_channel(self.endpoint) self.stub = AttackMitigatorStub(self.channel) def close(self): - if(self.channel is not None): self.channel.close() + if self.channel is not None: + self.channel.close() self.channel = None self.stub = None @RETRY_DECORATOR - def NotifyAttack(self, request : AttackDescription) -> AttackResponse: - LOGGER.debug('NotifyAttack request: {:s}'.format(str(request))) + def NotifyAttack(self, request: AttackDescription) -> AttackResponse: + LOGGER.debug("NotifyAttack request: {:s}".format(str(grpc_message_to_json(request)))) response = self.stub.NotifyAttack(request) - LOGGER.debug('NotifyAttack result: {:s}'.format(str(response))) + LOGGER.debug("NotifyAttack result: {:s}".format(str(grpc_message_to_json(response)))) return response diff --git a/src/opticalattackmitigator/client/__init__.py b/src/opticalattackmitigator/client/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/opticalattackmitigator/client/__init__.py +++ b/src/opticalattackmitigator/client/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/opticalattackmitigator/service/OpticalAttackMitigatorService.py b/src/opticalattackmitigator/service/OpticalAttackMitigatorService.py index a8bce0e581107456ac2b308ddc12182cb7355ef5..b688e26e9d8e482336d5fa6f92777a0d66efb363 100644 --- a/src/opticalattackmitigator/service/OpticalAttackMitigatorService.py +++ b/src/opticalattackmitigator/service/OpticalAttackMitigatorService.py @@ -12,25 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import grpc import logging from concurrent import futures -from grpc_health.v1.health import HealthServicer, OVERALL_HEALTH + +import grpc +from common.proto.optical_attack_mitigator_pb2_grpc import ( + add_AttackMitigatorServicer_to_server, +) +from grpc_health.v1.health import OVERALL_HEALTH, HealthServicer from grpc_health.v1.health_pb2 import HealthCheckResponse from grpc_health.v1.health_pb2_grpc import add_HealthServicer_to_server -from common.proto.optical_attack_mitigator_pb2_grpc import ( - add_AttackMitigatorServicer_to_server) +from common.Constants import ( + DEFAULT_GRPC_BIND_ADDRESS, + DEFAULT_GRPC_GRACE_PERIOD, + DEFAULT_GRPC_MAX_WORKERS, +) +from opticalattackmitigator.Config import GRPC_SERVICE_PORT from opticalattackmitigator.service.OpticalAttackMitigatorServiceServicerImpl import ( - OpticalAttackMitigatorServiceServicerImpl) -from opticalattackmitigator.Config import GRPC_SERVICE_PORT, GRPC_MAX_WORKERS, GRPC_GRACE_PERIOD + OpticalAttackMitigatorServiceServicerImpl, +) -BIND_ADDRESS = '0.0.0.0' +BIND_ADDRESS = "0.0.0.0" LOGGER = logging.getLogger(__name__) + class OpticalAttackMitigatorService: def __init__( - self, address=BIND_ADDRESS, port=GRPC_SERVICE_PORT, max_workers=GRPC_MAX_WORKERS, - grace_period=GRPC_GRACE_PERIOD): + self, + address=DEFAULT_GRPC_BIND_ADDRESS, + port=GRPC_SERVICE_PORT, + max_workers=DEFAULT_GRPC_MAX_WORKERS, + grace_period=DEFAULT_GRPC_GRACE_PERIOD, + ): self.address = address self.port = port @@ -43,30 +56,43 @@ class OpticalAttackMitigatorService: self.server = None def start(self): - self.endpoint = '{:s}:{:s}'.format(str(self.address), str(self.port)) - LOGGER.debug('Starting Service (tentative endpoint: {:s}, max_workers: {:s})...'.format( - str(self.endpoint), str(self.max_workers))) + self.endpoint = "{:s}:{:s}".format(str(self.address), str(self.port)) + LOGGER.debug( + "Starting Service (tentative endpoint: {:s}, max_workers: {:s})...".format( + str(self.endpoint), str(self.max_workers) + ) + ) self.pool = futures.ThreadPoolExecutor(max_workers=self.max_workers) - self.server = grpc.server(self.pool) # , interceptors=(tracer_interceptor,)) + self.server = grpc.server(self.pool) # , interceptors=(tracer_interceptor,)) self.attack_mitigator_servicer = OpticalAttackMitigatorServiceServicerImpl() - add_AttackMitigatorServicer_to_server(self.attack_mitigator_servicer, self.server) + add_AttackMitigatorServicer_to_server( + self.attack_mitigator_servicer, self.server + ) self.health_servicer = HealthServicer( - experimental_non_blocking=True, experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=1)) + experimental_non_blocking=True, + experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=1), + ) add_HealthServicer_to_server(self.health_servicer, self.server) port = self.server.add_insecure_port(self.endpoint) - self.endpoint = '{:s}:{:s}'.format(str(self.address), str(port)) - LOGGER.info('Listening on {:s}...'.format(self.endpoint)) + self.endpoint = "{:s}:{:s}".format(str(self.address), str(port)) + LOGGER.info("Listening on {:s}...".format(self.endpoint)) self.server.start() - self.health_servicer.set(OVERALL_HEALTH, HealthCheckResponse.SERVING) # pylint: disable=maybe-no-member + self.health_servicer.set( + OVERALL_HEALTH, HealthCheckResponse.SERVING + ) # pylint: disable=maybe-no-member - LOGGER.debug('Service started') + LOGGER.debug("Service started") def stop(self): - LOGGER.debug('Stopping service (grace period {:s} seconds)...'.format(str(self.grace_period))) + LOGGER.debug( + "Stopping service (grace period {:s} seconds)...".format( + str(self.grace_period) + ) + ) self.health_servicer.enter_graceful_shutdown() self.server.stop(self.grace_period) - LOGGER.debug('Service stopped') + LOGGER.debug("Service stopped") diff --git a/src/opticalattackmitigator/service/OpticalAttackMitigatorServiceServicerImpl.py b/src/opticalattackmitigator/service/OpticalAttackMitigatorServiceServicerImpl.py index e2ddcd62ea5aa219df36c6adb37f10853ff6ad89..8330527644200d7ac46336a60baebc9694eb3b43 100644 --- a/src/opticalattackmitigator/service/OpticalAttackMitigatorServiceServicerImpl.py +++ b/src/opticalattackmitigator/service/OpticalAttackMitigatorServiceServicerImpl.py @@ -12,28 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -import grpc, logging -from common.rpc_method_wrapper.Decorator import create_metrics, safe_and_metered_rpc_method -from common.proto.optical_attack_mitigator_pb2_grpc import ( - AttackMitigatorServicer) +import logging + +import grpc from common.proto.optical_attack_mitigator_pb2 import AttackDescription, AttackResponse +from common.proto.optical_attack_mitigator_pb2_grpc import AttackMitigatorServicer +from common.rpc_method_wrapper.Decorator import ( + create_metrics, + safe_and_metered_rpc_method, +) LOGGER = logging.getLogger(__name__) -SERVICE_NAME = 'OpticalAttackMitigator' -METHOD_NAMES = ['NotifyAttack'] +SERVICE_NAME = "OpticalAttackMitigator" +METHOD_NAMES = ["NotifyAttack"] METRICS = create_metrics(SERVICE_NAME, METHOD_NAMES) class OpticalAttackMitigatorServiceServicerImpl(AttackMitigatorServicer): - def __init__(self): - LOGGER.debug('Creating Servicer...') - LOGGER.debug('Servicer Created') + LOGGER.debug("Creating Servicer...") + LOGGER.debug("Servicer Created") @safe_and_metered_rpc_method(METRICS, LOGGER) - def NotifyAttack(self, request : AttackDescription, context : grpc.ServicerContext) -> AttackResponse: + def NotifyAttack( + self, request: AttackDescription, context: grpc.ServicerContext + ) -> AttackResponse: LOGGER.debug(f"NotifyAttack: {request}") response: AttackResponse = AttackResponse() - response.response_strategy_description = 'The AttackMitigator has received the attack description.' + response.response_strategy_description = ( + "The AttackMitigator has received the attack description." + ) return response diff --git a/src/opticalattackmitigator/service/__init__.py b/src/opticalattackmitigator/service/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/opticalattackmitigator/service/__init__.py +++ b/src/opticalattackmitigator/service/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/opticalattackmitigator/service/__main__.py b/src/opticalattackmitigator/service/__main__.py index 2fcda6af1d1cd327934f857ffea1af86ff08118a..87486dde32507ec05931655c0e7fa69826199b5f 100644 --- a/src/opticalattackmitigator/service/__main__.py +++ b/src/opticalattackmitigator/service/__main__.py @@ -12,35 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging, signal, sys, threading +import logging +import signal +import sys +import threading + +from common.Constants import ( + DEFAULT_GRPC_GRACE_PERIOD, + DEFAULT_GRPC_MAX_WORKERS, + ServiceNameEnum, +) +from common.Settings import ( + ENVVAR_SUFIX_SERVICE_HOST, + ENVVAR_SUFIX_SERVICE_PORT_GRPC, + get_env_var_name, + get_log_level, + get_metrics_port, + get_setting, + wait_for_environment_variables, +) +from opticalattackmitigator.Config import GRPC_SERVICE_PORT +from opticalattackmitigator.service.OpticalAttackMitigatorService import ( + OpticalAttackMitigatorService, +) from prometheus_client import start_http_server -from common.Settings import get_log_level, get_metrics_port, get_setting -from opticalattackmitigator.Config import ( - GRPC_SERVICE_PORT, GRPC_MAX_WORKERS, GRPC_GRACE_PERIOD) -from opticalattackmitigator.service.OpticalAttackMitigatorService import OpticalAttackMitigatorService terminate = threading.Event() LOGGER = None -def signal_handler(signal, frame): # pylint: disable=redefined-outer-name - LOGGER.warning('Terminate signal received') + +def signal_handler(signal, frame): # pylint: disable=redefined-outer-name + LOGGER.warning("Terminate signal received") terminate.set() + def main(): - global LOGGER # pylint: disable=global-statement - + global LOGGER # pylint: disable=global-statement + log_level = get_log_level() logging.basicConfig(level=log_level) LOGGER = logging.getLogger(__name__) - service_port = get_setting('OPTICALATTACKMITIGATORSERVICE_SERVICE_PORT_GRPC', default=GRPC_SERVICE_PORT) - max_workers = get_setting('MAX_WORKERS', default=GRPC_MAX_WORKERS ) - grace_period = get_setting('GRACE_PERIOD', default=GRPC_GRACE_PERIOD) + wait_for_environment_variables( + [ + get_env_var_name(ServiceNameEnum.SERVICE, ENVVAR_SUFIX_SERVICE_HOST), + get_env_var_name(ServiceNameEnum.SERVICE, ENVVAR_SUFIX_SERVICE_PORT_GRPC), + ] + ) + + service_port = get_setting( + "OPTICALATTACKMITIGATORSERVICE_SERVICE_PORT_GRPC", default=GRPC_SERVICE_PORT + ) + max_workers = get_setting("MAX_WORKERS", default=DEFAULT_GRPC_MAX_WORKERS) + grace_period = get_setting("GRACE_PERIOD", default=DEFAULT_GRPC_GRACE_PERIOD) - signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - LOGGER.info('Starting...') + LOGGER.info("Starting...") # Start metrics server metrics_port = get_metrics_port() @@ -48,17 +77,20 @@ def main(): # Starting CentralizedCybersecurity service grpc_service = OpticalAttackMitigatorService( - port=service_port, max_workers=max_workers, grace_period=grace_period) + port=service_port, max_workers=max_workers, grace_period=grace_period + ) grpc_service.start() # Wait for Ctrl+C or termination signal - while not terminate.wait(timeout=0.1): pass + while not terminate.wait(timeout=0.1): + pass - LOGGER.info('Terminating...') + LOGGER.info("Terminating...") grpc_service.stop() - LOGGER.info('Bye') + LOGGER.info("Bye") return 0 -if __name__ == '__main__': + +if __name__ == "__main__": sys.exit(main()) diff --git a/src/opticalattackmitigator/tests/__init__.py b/src/opticalattackmitigator/tests/__init__.py index 70a33251242c51f49140e596b8208a19dd5245f7..9953c820575d42fa88351cc8de022d880ba96e6a 100644 --- a/src/opticalattackmitigator/tests/__init__.py +++ b/src/opticalattackmitigator/tests/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/src/opticalattackmitigator/tests/test_unitary.py b/src/opticalattackmitigator/tests/test_unitary.py index 002639feae8dcef793c5a313f5dbcfa4e9e54db7..97b2cd10a48950e9ae0c192bfdb65136867998b2 100644 --- a/src/opticalattackmitigator/tests/test_unitary.py +++ b/src/opticalattackmitigator/tests/test_unitary.py @@ -12,31 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging, pytest -from opticalattackmitigator.Config import GRPC_SERVICE_PORT, GRPC_MAX_WORKERS, GRPC_GRACE_PERIOD -from opticalattackmitigator.client.OpticalAttackMitigatorClient import OpticalAttackMitigatorClient -from opticalattackmitigator.service.OpticalAttackMitigatorService import OpticalAttackMitigatorService +import logging + +import pytest from common.proto.optical_attack_mitigator_pb2 import AttackDescription, AttackResponse +from opticalattackmitigator.client.OpticalAttackMitigatorClient import ( + OpticalAttackMitigatorClient, +) +from opticalattackmitigator.Config import ( + GRPC_GRACE_PERIOD, + GRPC_MAX_WORKERS, + GRPC_SERVICE_PORT, +) +from opticalattackmitigator.service.OpticalAttackMitigatorService import ( + OpticalAttackMitigatorService, +) -port = 10000 + GRPC_SERVICE_PORT # avoid privileged ports +port = 10000 + GRPC_SERVICE_PORT # avoid privileged ports LOGGER = logging.getLogger(__name__) LOGGER.setLevel(logging.DEBUG) -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def optical_attack_mitigator_service(): _service = OpticalAttackMitigatorService( - port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD) + port=port, max_workers=GRPC_MAX_WORKERS, grace_period=GRPC_GRACE_PERIOD + ) _service.start() yield _service _service.stop() -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def optical_attack_mitigator_client(optical_attack_mitigator_service): - _client = OpticalAttackMitigatorClient(address='127.0.0.1', port=port) + _client = OpticalAttackMitigatorClient(address="127.0.0.1", port=port) yield _client _client.close() + def test_call_service(optical_attack_mitigator_client: OpticalAttackMitigatorClient): request = AttackDescription() optical_attack_mitigator_client.NotifyAttack(request)