# __main__.py — entry point of the optical attack manager service.
# (Removed web-UI residue — "Skip to content", file size, "Newer Older" —
# accidentally captured with the file content; it is not valid Python.)
import asyncio, grpc, random
from common.proto.optical_attack_detector_pb2_grpc import OpticalAttackDetectorServiceStub
import logging, signal, sys, time, threading
from multiprocessing import Manager, Process
from typing import List
from prometheus_client import start_http_server

from common.Settings import get_log_level, get_metrics_port, get_setting
from common.proto.context_pb2 import ContextIdList, Empty, EventTypeEnum, ServiceId, ServiceIdList
from context.client.ContextClient import ContextClient
from opticalattackmanager.Config import MONITORING_INTERVAL
from common.proto.monitoring_pb2 import KpiDescriptor
from common.proto.kpi_sample_types_pb2 import KpiSampleType
from monitoring.client.MonitoringClient import MonitoringClient

# Shared shutdown flag: set by signal_handler(), polled by the monitoring
# loop and the context-updates process to stop gracefully.
terminate = threading.Event()
# Module-wide logger; initialized in main() once the log level is known.
LOGGER = None

# For more channel options, please see https://grpc.io/grpc/core/group__grpc__arg__keys.html
CHANNEL_OPTIONS = [('grpc.lb_policy_name', 'pick_first'),
                   ('grpc.enable_retries', True),
                   ('grpc.keepalive_timeout_ms', 10000)]
# TODO: configure retries

def signal_handler(signal, frame): # pylint: disable=redefined-outer-name
    """Handle SIGINT/SIGTERM by setting the shared terminate event."""
    LOGGER.warning('Terminate signal received')
    terminate.set()


async def detect_attack(endpoint, context_id, service_id):
    """Request attack detection for one service from the detector component.

    Opens a short-lived insecure gRPC channel to `endpoint`, builds a
    ServiceId message from `context_id` and `service_id`, and awaits the
    DetectAttack RPC with a 10-second deadline.

    Raises:
        grpc.aio.AioRpcError: if the RPC fails or the deadline expires.
    """
    async with grpc.aio.insecure_channel(target=endpoint,
                                         options=CHANNEL_OPTIONS) as channel:
        stub = OpticalAttackDetectorServiceStub(channel)

        service = ServiceId()
        service.context_id.context_uuid.uuid = context_id
        service.service_uuid.uuid = str(service_id)
        # Timeout in seconds.
        # Please refer gRPC Python documents for more detail. https://grpc.io/grpc/python/grpc.html
        await stub.DetectAttack(service, timeout=10)
    # Fixed: was `print("Greeter client received:", service_id)` — a stray
    # line copied from the gRPC hello-world example; log via LOGGER instead.
    LOGGER.debug('Attack detection requested for service %s', service_id)


async def monitor_services(service_list: List[ServiceId]):
    """Periodically trigger attack detection for every monitored service.

    Runs until the module-level `terminate` event is set. Each cycle fires
    one DetectAttack request per service in `service_list` (a Manager().list
    of dicts with 'context'/'service' keys), then sleeps for the remainder
    of the monitoring interval.
    """
    monitoring_interval = int(get_setting('MONITORING_INTERVAL', default=MONITORING_INTERVAL))

    host = get_setting('OPTICALATTACKDETECTORSERVICE_SERVICE_HOST')
    port = get_setting('OPTICALATTACKDETECTORSERVICE_SERVICE_PORT_GRPC')
    endpoint = '{:s}:{:s}'.format(str(host), str(port))
    LOGGER.info('Starting execution of the async loop')

    while not terminate.is_set():
        # Fixed: the loop previously slept both at the top AND at the bottom
        # of each iteration (double wait); the top sleep is removed.
        LOGGER.info('Starting new monitoring cycle...')

        start_time = time.time()

        # Fixed: was `[await aw for aw in tasks]`, which ran the RPCs one at
        # a time; gather() runs them concurrently.
        tasks = [
            detect_attack(endpoint, service['context'], service['service'])
            for service in service_list
        ]
        await asyncio.gather(*tasks)

        end_time = time.time()

        diff = end_time - start_time
        LOGGER.info('Monitoring loop with {} services took {} seconds ({:.2f}%)... Waiting for {:.2f} seconds...'.format(len(service_list), diff, (diff / monitoring_interval) * 100, monitoring_interval - diff))

        if diff / monitoring_interval > 0.9:
            LOGGER.warning('Monitoring loop is taking {} % of the desired time ({} seconds)'.format((diff / monitoring_interval) * 100, monitoring_interval))

        # Fixed: (1) `time.sleep` blocked the asyncio event loop — use
        # asyncio.sleep; (2) a cycle longer than the interval made the
        # argument negative, raising ValueError — clamp at zero.
        await asyncio.sleep(max(0.0, monitoring_interval - diff))


def create_kpi(client: MonitoringClient, service_id):
    """Register a security-status KPI for `service_id` and return the result.

    Builds a KpiDescriptor (sample type UNKNOWN) bound to the service and
    submits it through the given monitoring client.
    """
    descriptor = KpiDescriptor()
    descriptor.kpi_description = 'Security status of service {}'.format(service_id)
    descriptor.service_id.service_uuid.uuid = service_id
    descriptor.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN
    created_kpi = client.SetKpi(descriptor)
    LOGGER.info('Created KPI {}...'.format(created_kpi.kpi_id))
    return created_kpi


def get_context_updates(service_list: List[ServiceId]):
    """Track service create/remove events and keep `service_list` in sync.

    Runs in a separate process (see main()); `service_list` is a
    multiprocessing Manager().list, so mutations here are visible to the
    monitoring loop in the parent process. Blocks on the context event
    stream until `terminate` is set.
    """
    # Each process needs its own client instances (gRPC channels are not
    # safely shareable across fork).
    LOGGER.info('Connecting with context and monitoring components...')
    context_client: ContextClient = ContextClient()
    monitoring_client: MonitoringClient = MonitoringClient()
    # Fixed: the clients were never connect()ed, unlike in main() which
    # connects before issuing RPCs.
    context_client.connect()
    monitoring_client.connect()
    LOGGER.info('Connected successfully... Waiting for events...')

    # Grace period before subscribing; presumably waits for dependent
    # components to come up — TODO: confirm and make configurable.
    time.sleep(20)

    for event in context_client.GetServiceEvents(Empty()):
        LOGGER.info('Event received: {}'.format(event))
        if event.event.event_type == EventTypeEnum.EVENTTYPE_CREATE:
            LOGGER.info('Service created: {}'.format(event.service_id))
            kpi_id = create_kpi(monitoring_client, event.service_id.service_uuid.uuid)
            service_list.append({'context': event.service_id.context_id.context_uuid.uuid, 'service': event.service_id.service_uuid.uuid, 'kpi': kpi_id.kpi_id.uuid})

        elif event.event.event_type == EventTypeEnum.EVENTTYPE_REMOVE:
            LOGGER.info('Service removed: {}'.format(event.service_id))
            # Find the service and drop it from the monitored list.
            for service in service_list:
                if service['service'] == event.service_id.service_uuid.uuid and service['context'] == event.service_id.context_id.context_uuid.uuid:
                    service_list.remove(service)
                    break

        # NOTE(review): this check only runs after an event arrives, so a
        # quiet event stream delays shutdown; the parent kills this process
        # anyway (main() calls process_context.kill()).
        if terminate.is_set():
            LOGGER.warning('Stopping execution of the get_context_updates...')
            context_client.close()
            monitoring_client.close()
            break
        LOGGER.debug('Waiting for next event...')


def main():
    """Bootstrap the optical attack manager.

    Sets up logging and signal handling, seeds the shared service list from
    the context component, spawns the context-event watcher process, and
    runs the async monitoring loop until termination. Returns 0 on exit.
    """
    global LOGGER # pylint: disable=global-statement

    log_level = get_log_level()
    logging.basicConfig(level=log_level)
    LOGGER = logging.getLogger(__name__)

    signal.signal(signal.SIGINT,  signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    LOGGER.info('Starting...')

    # Start metrics server
    metrics_port = get_metrics_port()
    # start_http_server(metrics_port)  # TODO: uncomment this line

    LOGGER.info('Connecting with context component...')
    context_client: ContextClient = ContextClient()
    context_client.connect()
    # Fixed: a MonitoringClient is required by create_kpi() below but was
    # never instantiated in main().
    monitoring_client: MonitoringClient = MonitoringClient()
    monitoring_client.connect()
    LOGGER.info('Connected successfully...')

    # Thread/process-safe list shared with the watcher process and the
    # async monitoring loop.
    service_list = Manager().list()
    # TODO(review): hard-coded sample services — look like debug leftovers;
    # confirm whether they should be removed before production.
    service_list.append({'context': 'admin', "service": "1213"})
    service_list.append({'context': 'admin', "service": "1456"})

    context_ids: ContextIdList = context_client.ListContextIds(Empty())

    # Populate with the services that already exist.
    for context_id in context_ids.context_ids:
        context_services: ServiceIdList = context_client.ListServiceIds(context_id)
        for service in context_services.service_ids:
            # Fixed: was `create_kpi(service.service_uuid.uuid)` — missing
            # the required client argument (TypeError at runtime).
            kpi_id = create_kpi(monitoring_client, service.service_uuid.uuid)
            # Fixed: store the KPI uuid string, consistent with the entries
            # appended by get_context_updates().
            service_list.append({'context': context_id.context_uuid.uuid, 'service': service.service_uuid.uuid, 'kpi': kpi_id.kpi_id.uuid})

    context_client.close()
    monitoring_client.close()

    # Background process that monitors service addition/removal.
    process_context = Process(target=get_context_updates, args=(service_list,))
    process_context.start()

    # Runs the async monitoring loop in the foreground; returns only once
    # `terminate` is set inside monitor_services().
    loop = asyncio.get_event_loop()
    loop.run_until_complete(monitor_services(service_list))

    # Wait for Ctrl+C or termination signal
    while not terminate.wait(timeout=0.1): pass

    LOGGER.info('Terminating...')
    process_context.kill()

    LOGGER.info('Bye')
    return 0

# Standard script entry point; propagate main()'s return value as the
# process exit code.
if __name__ == '__main__':
    sys.exit(main())