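"""Entry point of the Optical Attack Manager component.

Starts a Prometheus metrics endpoint, registers a monitoring KPI for every service
already present in the Context component, keeps that list up to date by consuming
Context service events in a background process, and periodically dispatches one
'detect_attack' Celery task per monitored service through a Redis broker.
"""
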
import logging, signal, sys, time, threading
from multiprocessing import Manager, Process
from typing import List
from prometheus_client import start_http_server
from celery import Celery

from common.Settings import get_log_level, get_metrics_port, get_setting
from common.proto.context_pb2 import ContextIdList, Empty, EventTypeEnum, ServiceId, ServiceIdList
from context.client.ContextClient import ContextClient
from opticalattackmanager.Config import MONITORING_INTERVAL
from common.proto.monitoring_pb2 import KpiDescriptor
from common.proto.kpi_sample_types_pb2 import KpiSampleType
from monitoring.client.MonitoringClient import MonitoringClient

terminate = threading.Event()
LOGGER = None

def signal_handler(signal, frame): # pylint: disable=redefined-outer-name
    LOGGER.warning('Terminate signal received')
    terminate.set()


def monitor_services(app: Celery, service_list: List[ServiceId]):

    # environment variables are read as strings; cast to float for the time arithmetic below
    monitoring_interval = float(get_setting('MONITORING_INTERVAL', default=MONITORING_INTERVAL))

    while not terminate.is_set():

        LOGGER.info('Starting new monitoring cycle...')

        start_time = time.time()

        try:
            tasks = []

            for service in service_list:
                LOGGER.debug('Scheduling service: {}'.format(service))
                tasks.append(
                    app.send_task('detect_attack', (service['context'], service['service'], service['kpi']))
                )
            
            for task in tasks:
                LOGGER.debug('Waiting for task {}...'.format(task))
                result = task.get()
                LOGGER.debug('Result for task {} is {}...'.format(task, result))
        except Exception as e:
            LOGGER.exception(e)
        
        end_time = time.time()

        diff = end_time - start_time
        LOGGER.info('Monitoring loop with {} services took {:.3f} seconds...'.format(len(service_list), diff))

        if diff / monitoring_interval > 0.9:
            LOGGER.warning('Monitoring loop is taking {:.1f} % of the desired time ({} seconds)'.format(
                (diff / monitoring_interval) * 100, monitoring_interval))

        # sleep only for the remainder of the interval; never a negative amount
        time.sleep(max(0, monitoring_interval - diff))
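
# Note: the 'detect_attack' tasks dispatched above are consumed by a separate worker
# (not part of this module). A minimal, hypothetical sketch of such a worker-side
# registration, assuming the same Celery app configuration, could look like:
#
#     @app.task(name='detect_attack')
#     def detect_attack(context_uuid: str, service_uuid: str, kpi_uuid: str):
#         # run the attack-detection inference for this service and report to its KPI
#         ...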


def create_kpi(client: MonitoringClient, service_id):
    # create kpi
    kpi_description: KpiDescriptor = KpiDescriptor()
    kpi_description.kpi_description = 'Security status of service {}'.format(service_id)
    kpi_description.service_id.service_uuid.uuid = service_id
    kpi_description.kpi_sample_type = KpiSampleType.KPISAMPLETYPE_UNKNOWN
    new_kpi = client.SetKpi(kpi_description)
    LOGGER.info('Created KPI {}...'.format(new_kpi.kpi_id))
    return new_kpi


def get_context_updates(service_list: List[ServiceId]):
    # clients are created inside this function so the gRPC channels are not shared
    # with the parent process (this function runs in its own Process)
    LOGGER.info('Connecting with context and monitoring components...')
    context_client: ContextClient = ContextClient()
    monitoring_client: MonitoringClient = MonitoringClient()
    LOGGER.info('Connected successfully... Waiting for events...')

    # grace period before subscribing to the event stream
    time.sleep(20)

    for event in context_client.GetServiceEvents(Empty()):
        LOGGER.info('Event received: {}'.format(event))
        if event.event.event_type == EventTypeEnum.EVENTTYPE_CREATE:
            LOGGER.info('Service created: {}'.format(event.service_id))
            kpi_id = create_kpi(monitoring_client, event.service_id.service_uuid.uuid)
            service_list.append({
                'context': event.service_id.context_id.context_uuid.uuid,
                'service': event.service_id.service_uuid.uuid,
                'kpi': kpi_id.kpi_id.uuid,
            })
            
        elif event.event.event_type == EventTypeEnum.EVENTTYPE_REMOVE:
            LOGGER.info('Service removed: {}'.format(event.service_id))
            # find service and remove it from the list of currently monitored
            for service in service_list:
                if (service['service'] == event.service_id.service_uuid.uuid and
                        service['context'] == event.service_id.context_id.context_uuid.uuid):
                    service_list.remove(service)
                    break
            
        if terminate.is_set():
            LOGGER.warning('Stopping execution of get_context_updates...')
            context_client.close()
            monitoring_client.close()
            break  # stop consuming the event stream
        LOGGER.debug('Waiting for next event...')


def main():
    global LOGGER # pylint: disable=global-statement

    log_level = get_log_level()
    logging.basicConfig(level=log_level)
    LOGGER = logging.getLogger(__name__)

    signal.signal(signal.SIGINT,  signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    LOGGER.info('Starting...')

    # Start metrics server
    metrics_port = get_metrics_port()
    start_http_server(metrics_port)

    LOGGER.info('Connecting with context component...')
    context_client: ContextClient = ContextClient()
    context_client.connect()
    LOGGER.info('Connected successfully...')

    LOGGER.info('Connecting to Redis...')
    REDIS_PASSWORD = get_setting('REDIS_PASSWORD')
    REDIS_HOST = get_setting('CACHINGSERVICE_SERVICE_HOST')
    REDIS_PORT = get_setting('CACHINGSERVICE_SERVICE_PORT_REDIS')
    BROKER_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0'
    BACKEND_URL = f'redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/1'
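    # the broker (task queue) and the result backend share the same Redis instance,
    # using logical databases 0 and 1, respectively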
    app = Celery(
        'cybersecurity',
        broker=BROKER_URL,
        backend=BACKEND_URL
    )
    LOGGER.info('Connected to Redis...')

    # process-safe list shared with the background context-updates process;
    # each entry is {'context': <context_uuid>, 'service': <service_uuid>, 'kpi': <kpi_uuid>}
    service_list = Manager().list()

    context_ids: ContextIdList = context_client.ListContextIds(Empty())

    # client for the monitoring component, used to register a KPI per existing service
    monitoring_client: MonitoringClient = MonitoringClient()

    # populate with initial services
    for context_id in context_ids.context_ids:
        context_services: ServiceIdList = context_client.ListServiceIds(context_id)
        for service in context_services.service_ids:
            kpi_id = create_kpi(monitoring_client, service.service_uuid.uuid)
            service_list.append({
                'context': context_id.context_uuid.uuid,
                'service': service.service_uuid.uuid,
                'kpi': kpi_id.kpi_id.uuid,
            })
    
    context_client.close()

    # starting background process to monitor service addition/removal
    process_context = Process(target=get_context_updates, args=(service_list,))
    process_context.start()

    monitor_services(app, service_list)

    # process_security_loop = Process(target=monitor_services, args=(app, service_list))
    # process_security_loop.start()

    # Wait for Ctrl+C or termination signal
    while not terminate.wait(timeout=0.1): pass

    LOGGER.info('Terminating...')
    process_context.kill()
    # process_security_loop.kill()

    LOGGER.info('Bye')
    return 0

if __name__ == '__main__':
    sys.exit(main())