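# Entry point of the optical attack manager: tracks the services known to the
# Context component, registers a security-status KPI for each one in the
# Monitoring component, and periodically asks the optical attack detector to
# assess every tracked service.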
import asyncio
import logging
import signal
import sys
import threading
import time
from multiprocessing import Manager, Process
from typing import List

from common.Constants import ServiceNameEnum
from common.proto.asyncio.optical_attack_detector_grpc import (
    OpticalAttackDetectorServiceStub,
)
from common.proto.asyncio.optical_attack_detector_pb2 import DetectionRequest
from common.proto.context_pb2 import ContextIdList, Empty, EventTypeEnum, ServiceIdList
from common.proto.kpi_sample_types_pb2 import KpiSampleType
from common.proto.monitoring_pb2 import KpiDescriptor
from common.Settings import (
    ENVVAR_SUFIX_SERVICE_HOST,
    ENVVAR_SUFIX_SERVICE_PORT_GRPC,
    get_env_var_name,
    get_log_level,
    get_metrics_port,
    get_setting,
    wait_for_environment_variables,
)
from common.tools.grpc.Tools import grpc_message_to_json_string
from context.client.ContextClient import ContextClient
from monitoring.client.MonitoringClient import MonitoringClient
from opticalattackmanager.Config import MONITORING_INTERVAL
from opticalattackmanager.utils.EventsCollector import EventsCollector
from prometheus_client import start_http_server, Histogram, Counter
# Prometheus metrics exported by this component (served via start_http_server()).
# TODO: adjust histogram buckets to more realistic values
LOOP_TIME = Histogram('optical_security_loop_seconds', 'Time taken by each security loop')
DROP_COUNTER = Counter('optical_security_dropped_assessments', 'Dropped assessments due to detector timeout')

# Shutdown flag: the main loop and the context-updates loop both wait on it.
terminate = threading.Event()

# Module logger; main() reconfigures logging and rebinds it via ``global LOGGER``.
LOGGER = logging.getLogger(__name__)


def signal_handler(signal, frame):  # pylint: disable=redefined-outer-name
    """Handle SIGINT/SIGTERM by requesting a graceful shutdown.

    Installing this handler replaces the default signal behaviour, so it must
    set the shared ``terminate`` event; otherwise the signal would only be
    logged and the process would keep running.

    Args:
        signal: Number of the received signal (shadows the ``signal`` module).
        frame: Interrupted stack frame (unused).
    """
    LOGGER.warning("Terminate signal received")
    terminate.set()


async def detect_attack(
    host: str,
    port: int,
    context_id: str,
    service_id: str,
    kpi_id: str,
    timeout: float = 10.0,
) -> None:
    """Ask the optical attack detector to assess a single service.

    Opens a gRPC channel to the detector at ``host:port`` and issues one
    ``DetectAttack`` request for the given service/KPI. Failures are logged
    and counted in ``DROP_COUNTER`` instead of being propagated, so a single
    failing service does not abort the whole monitoring cycle.

    Args:
        host: Detector service host.
        port: Detector service gRPC port.
        context_id: UUID of the context the service belongs to.
        service_id: UUID of the service to assess.
        kpi_id: UUID of the KPI associated with the service.
        timeout: Timeout passed to the ``DetectAttack`` call.
    """
    # NOTE(review): ``Channel`` is presumably ``grpclib.client.Channel``; its
    # import is not visible in this file -- verify it is imported.
    try:
        LOGGER.info("Sending request for %s...", service_id)
        async with Channel(host, port) as channel:
            stub = OpticalAttackDetectorServiceStub(channel)

            request: DetectionRequest = DetectionRequest()
            request.service_id.context_id.context_uuid.uuid = context_id
            request.service_id.service_uuid.uuid = str(service_id)
            request.kpi_id.kpi_id.uuid = kpi_id

            await stub.DetectAttack(request, timeout=timeout)
        LOGGER.info("Monitoring finished for %s", service_id)
    except Exception:  # pylint: disable=broad-except
        # Any failure (including an expired timeout) means this cycle's
        # assessment for the service is lost: count it and keep going.
        DROP_COUNTER.inc()
        LOGGER.exception("Exception while processing service_id %s", service_id)


async def monitor_services(service_list: List):
    """Periodically ask the detector to assess every monitored service.

    Every ``MONITORING_INTERVAL`` seconds one ``detect_attack`` request is
    issued per entry of ``service_list``. The requests run concurrently and
    each one gets at most 90% of the interval. The duration of every cycle is
    recorded in the ``LOOP_TIME`` histogram. Returns once ``terminate`` is set.

    Args:
        service_list: Shared list of ``{"context", "service", "kpi"}`` dicts,
            which is modified concurrently by the context-updates process.
    """
    monitoring_interval = int(
        get_setting("MONITORING_INTERVAL", default=MONITORING_INTERVAL)
    )
    host = get_setting("OPTICALATTACKDETECTORSERVICE_SERVICE_HOST")
    port = int(get_setting("OPTICALATTACKDETECTORSERVICE_SERVICE_PORT_GRPC"))
    LOGGER.info("Starting execution of the async loop")

    while not terminate.is_set():
        # One-shot snapshot: the list is mutated by another process, and
        # iterating it element by element could skip or repeat entries.
        services = service_list[:]
        if not services:
            LOGGER.debug("No services to monitor...")
            # asyncio.sleep (not time.sleep) so the event loop is not blocked.
            await asyncio.sleep(monitoring_interval)
            continue

        LOGGER.info("Starting new monitoring cycle...")
        start_time = time.monotonic()
        # gather() runs the requests concurrently; awaiting them one after
        # another would make the cycle last up to N * 90% of the interval.
        await asyncio.gather(
            *(
                detect_attack(
                    host,
                    port,
                    service["context"],
                    service["service"],
                    service["kpi"],
                    # allow at most 90% of the monitoring interval to succeed
                    monitoring_interval * 0.9,
                )
                for service in services
            )
        )
        time_taken = time.monotonic() - start_time
        LOOP_TIME.observe(time_taken)

        remaining = max(0.0, monitoring_interval - time_taken)
        LOGGER.info(
            "Monitoring loop with {} services took {:.3f} seconds ({:.2f}%)... "
            "Waiting for {:.2f} seconds...".format(
                len(services),
                time_taken,
                (time_taken / monitoring_interval) * 100,
                remaining,
            )
        )

        if time_taken / monitoring_interval > 0.9:
            LOGGER.warning(
                "Monitoring loop is taking {} % of the desired time "
                "({} seconds)".format(
                    (time_taken / monitoring_interval) * 100, monitoring_interval
                )
            )
        if remaining > 0:
            await asyncio.sleep(remaining)


def create_kpi(client: MonitoringClient, service_id):
    """Register a KPI that holds the security status of a service.

    Args:
        client: Monitoring client used to create the KPI.
        service_id: UUID of the service the KPI describes.

    Returns:
        The KPI identifier returned by ``client.SetKpi``; callers read
        ``.kpi_id.uuid`` from it.
    """
    kpi_description: KpiDescriptor = KpiDescriptor()
    kpi_description.kpi_description = "Security status of service {}".format(service_id)
    kpi_description.service_id.service_uuid.uuid = service_id
    kpi_description.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN
    new_kpi = client.SetKpi(kpi_description)
    LOGGER.info("Created KPI: {}".format(grpc_message_to_json_string(new_kpi)))
    return new_kpi


def _remove_service(service_list: List, context_uuid: str, service_uuid: str) -> None:
    """Drop the first entry of ``service_list`` matching the context/service."""
    # Scan a one-shot snapshot so the shared list is not mutated while iterated.
    for service in service_list[:]:
        if service["service"] == service_uuid and service["context"] == context_uuid:
            service_list.remove(service)
            break


def get_context_updates(service_list: List):
    """Keep ``service_list`` in sync with service creation/removal events.

    Run by main() in a separate process. It subscribes to Context events.
    For every created service it registers a KPI and appends a
    ``{"context", "service", "kpi"}`` entry. For every removed service it
    drops the matching entry. Returns once ``terminate`` is set.

    Args:
        service_list: Shared list also read by the monitoring loop.
    """
    LOGGER.info("Connecting with context and monitoring components...")
    context_client: ContextClient = ContextClient()
    monitoring_client: MonitoringClient = MonitoringClient()

    events_collector: EventsCollector = EventsCollector(context_client)
    events_collector.start()

    LOGGER.info("Connected successfully... Waiting for events...")
    # get_event() already waits up to 1 s, so only poll the flag here; waiting
    # on it as well would delay the handling of every event by another second.
    while not terminate.is_set():
        event = events_collector.get_event(block=True, timeout=1)
        if event is None:
            LOGGER.debug("No event received")
            continue
        LOGGER.info("Event received: {}".format(grpc_message_to_json_string(event)))
        service_uuid = event.service_id.service_uuid.uuid
        context_uuid = event.service_id.context_id.context_uuid.uuid

        if event.event.event_type == EventTypeEnum.EVENTTYPE_CREATE:
            LOGGER.info(
                "Service created: {}".format(
                    grpc_message_to_json_string(event.service_id)
                )
            )
            kpi_id = create_kpi(monitoring_client, service_uuid)
            service_list.append(
                {
                    "context": context_uuid,
                    "service": service_uuid,
                    "kpi": kpi_id.kpi_id.uuid,
                }
            )

        elif event.event.event_type == EventTypeEnum.EVENTTYPE_REMOVE:
            LOGGER.info(
                "Service removed: {}".format(
                    grpc_message_to_json_string(event.service_id)
                )
            )
            # find service and remove it from the list of currently monitored
            _remove_service(service_list, context_uuid, service_uuid)


def main():
    """Wire up the clients, seed the service list and run the monitoring loops.

    Returns:
        Process exit status (0).
    """
    global LOGGER  # pylint: disable=global-statement

    log_level = get_log_level()
    logging.basicConfig(level=log_level)
    LOGGER = logging.getLogger(__name__)

    logging.getLogger("hpack").setLevel(logging.CRITICAL)

    # Check the host/port environment variables of each peer component.
    for service_name in (
        ServiceNameEnum.MONITORING,
        ServiceNameEnum.CONTEXT,
        ServiceNameEnum.OPTICALATTACKDETECTOR,
    ):
        wait_for_environment_variables(
            [
                get_env_var_name(service_name, ENVVAR_SUFIX_SERVICE_HOST),
                get_env_var_name(service_name, ENVVAR_SUFIX_SERVICE_PORT_GRPC),
            ]
        )

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Start metrics server
    metrics_port = get_metrics_port()
    start_http_server(metrics_port)

    LOGGER.info("Connecting with context component...")
    context_client: ContextClient = ContextClient()
    monitoring_client: MonitoringClient = MonitoringClient()
    LOGGER.info("Connected successfully...")

    # Manager-backed list, shared with the context-updates process.
    manager = Manager()
    service_list = manager.list()

    # TODO: remove this loop to stop monitoring dummy services
    for dummy_service in ("1213", "1456"):
        kpi_id = create_kpi(monitoring_client, dummy_service)
        service_list.append(
            {"context": "admin", "service": dummy_service, "kpi": kpi_id.kpi_id.uuid}
        )

    context_ids: ContextIdList = context_client.ListContextIds(Empty())

    # populate with initial services
    for context_id in context_ids.context_ids:
        context_services: ServiceIdList = context_client.ListServiceIds(context_id)
        for service in context_services.service_ids:
            # in case of a service restart, monitoring component will not duplicate KPIs
            # but rather return the existing KPI if that's the case
            kpi_id = create_kpi(monitoring_client, service.service_uuid.uuid)
            service_list.append(
                {
                    "context": context_id.context_uuid.uuid,
                    "service": service.service_uuid.uuid,
                    "kpi": kpi_id.kpi_id.uuid,
                }
            )

    # starting background process to monitor service addition/removal
    process_context = Process(target=get_context_updates, args=(service_list,))
    process_context.start()

    time.sleep(5)  # wait for the context updates to startup

    # Blocks for as long as the monitoring loop runs.
    asyncio.run(monitor_services(service_list))

    # Wait for Ctrl+C or a termination signal.
    while not terminate.wait(timeout=0.1):
        pass

    LOGGER.info("Terminating...")
    process_context.kill()
    process_context.join()  # reap the child so it does not linger as a zombie
    manager.shutdown()
    return 0