import logging, signal, sys, time, threading, multiprocessing
from prometheus_client import start_http_server
from common.Settings import get_setting
from opticalcentralizedattackdetector.Config import (
    GRPC_SERVICE_PORT, GRPC_MAX_WORKERS, GRPC_GRACE_PERIOD, LOG_LEVEL, METRICS_PORT,
    MONITORING_INTERVAL, CONTEXT_SERVICE_ADDRESS, SERVICE_SERVICE_ADDRESS, INFERENCE_SERVICE_ADDRESS, MONITORING_SERVICE_ADDRESS)
from context.Config import GRPC_SERVICE_PORT as CONTEXT_GRPC_SERVICE_PORT
from context.client.ContextClient import ContextClient
from opticalcentralizedattackdetector.proto.context_pb2 import (Empty,
    Context,  ContextId,  ContextIdList,  ContextList,
    Service,  ServiceId,  ServiceIdList,  ServiceList
)
# from monitoring.Config import GRPC_SERVICE_PORT as MONITORING_GRPC_SERVICE_PORT
# from monitoring.client.monitoring_client import MonitoringClient
from service.Config import GRPC_SERVICE_PORT as SERVICE_GRPC_SERVICE_PORT
from service.client.ServiceClient import ServiceClient
from opticalcentralizedattackdetector.service.OpticalCentralizedAttackDetectorService import OpticalCentralizedAttackDetectorService

# Module logger; remains None until main() configures logging and assigns it.
LOGGER = None
# Raised by signal_handler on SIGINT/SIGTERM; the long-running loops check it to know when to stop.
terminate = threading.Event()

def signal_handler(signal, frame): # pylint: disable=redefined-outer-name
    """React to a termination signal by logging it and raising the shared ``terminate`` flag.

    Registered by main() for SIGINT and SIGTERM.

    Args:
        signal: number of the received signal (unused).
        frame: stack frame that was interrupted (unused).
    """
    message = 'Terminate signal received'
    LOGGER.warning(message)
    terminate.set()

def _list_services(context_client):
    """Return every Service of every context currently known to the Context service.

    Args:
        context_client: ContextClient used for the ListContextIds / ListServices calls.

    Returns:
        list of Service messages, grouped in context order.
    """
    context_ids: ContextIdList = context_client.ListContextIds(Empty())
    services = []
    for context_id in context_ids.context_ids:
        # ListServices returns a ServiceList (full Service messages, not only their ids)
        context_services: ServiceList = context_client.ListServices(context_id)
        services.extend(context_services.services)
    return services


def detect_attack(monitoring_interval):
    """Periodically poll the Context service and inspect every known service.

    Runs until the module-level ``terminate`` event is set (see ``signal_handler``).
    Each iteration lists all contexts, collects the services of every context and
    logs their endpoints; the attack detection and mitigation steps are still TODO.
    A failing iteration is logged and retried after the next interval instead of
    terminating the loop. The Context and Service clients are always closed on exit.

    Args:
        monitoring_interval: seconds to wait between two iterations; a number or a
            numeric string (e.g. a value read from an environment variable).
    """
    LOGGER.info("Starting the attack detection loop")
    # Accept numeric strings coming from the environment (time/wait need a number)
    interval = float(monitoring_interval)
    context_client: ContextClient = ContextClient(address=CONTEXT_SERVICE_ADDRESS, port=CONTEXT_GRPC_SERVICE_PORT)
    # TODO: also create a MonitoringClient (MONITORING_SERVICE_ADDRESS) once KPI retrieval is implemented.
    service_client: ServiceClient = ServiceClient(address=SERVICE_SERVICE_ADDRESS, port=SERVICE_GRPC_SERVICE_PORT)
    try:
        while not terminate.is_set():
            try:
                services = _list_services(context_client)

                # get monitoring data for each of the current services
                for service in services:
                    for endpoint in service.service_endpoint_ids:
                        # TODO: get the instant KPIs of this endpoint (how to get all KPIs of a device?)
                        LOGGER.warning(
                            'service: %s\t endpoint: %s\tdevice: %s',
                            service.service_id.service_uuid.uuid,
                            endpoint.endpoint_uuid.uuid,
                            endpoint.device_id.device_uuid.uuid)

                # TODO: run attack detection for every service
                # TODO: if an attack is detected, run the attack mitigator
            except Exception:  # pylint: disable=broad-except
                # One failed iteration (e.g. a failing gRPC call) must not kill the detection loop
                LOGGER.exception('Attack detection iteration failed')

            LOGGER.debug("Sleeping for {} seconds...".format(monitoring_interval))
            # Unlike time.sleep(), Event.wait() returns as soon as a termination signal sets `terminate`
            terminate.wait(timeout=interval)
        LOGGER.warning("Stopping execution...")
    finally:
        context_client.close()
        service_client.close()

def main():
    """Service entry point: configure logging, start the servers and run the detection loop.

    Starts the Prometheus metrics server and the OpticalCentralizedAttackDetector gRPC
    service, then runs ``detect_attack`` until SIGINT/SIGTERM sets ``terminate``. The gRPC
    service is stopped even if the detection loop raises.

    Returns:
        process exit code (0 on a clean shutdown).
    """
    global LOGGER # pylint: disable=global-statement

    # Values taken from the environment are presumably strings (TODO confirm against
    # common.Settings.get_setting); convert the numeric settings so env-provided and
    # default values reach their consumers with the same type.
    service_port = int(get_setting('OPTICALCENTRALIZEDATTACKDETECTORSERVICE_SERVICE_PORT_GRPC', default=GRPC_SERVICE_PORT))
    max_workers  = int(get_setting('MAX_WORKERS',                                               default=GRPC_MAX_WORKERS ))
    grace_period = get_setting('GRACE_PERIOD',                                                  default=GRPC_GRACE_PERIOD)
    log_level    = get_setting('LOG_LEVEL',                                                     default=LOG_LEVEL        )
    metrics_port = int(get_setting('METRICS_PORT',                                              default=METRICS_PORT     ))
    monitoring_interval = float(get_setting('MONITORING_INTERVAL',                              default=MONITORING_INTERVAL))

    logging.basicConfig(level=log_level)
    LOGGER = logging.getLogger(__name__)

    # Handlers are registered only after LOGGER is set, because signal_handler logs through it
    signal.signal(signal.SIGINT,  signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    LOGGER.info('Starting...')

    # Start metrics server
    start_http_server(metrics_port)

    # Start the OpticalCentralizedAttackDetector gRPC service
    grpc_service = OpticalCentralizedAttackDetectorService(
        port=service_port, max_workers=max_workers, grace_period=grace_period)
    grpc_service.start()

    try:
        # Blocks in this thread until a termination signal sets `terminate`
        detect_attack(monitoring_interval)

        # Wait for Ctrl+C or termination signal
        while not terminate.wait(timeout=0.1): pass
    finally:
        # Always release the gRPC server, even if the detection loop crashed
        LOGGER.info('Terminating...')
        grpc_service.stop()

    LOGGER.info('Bye')
    return 0

if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status
    exit_code = main()
    sys.exit(exit_code)
