import logging
from typing import Any, Dict, List, Optional, Tuple, Type

from common.orm.HighLevel import get_object, get_related_objects
from common.rpc_method_wrapper.ServiceExceptions import NotFoundException
from context.client.ContextClient import ContextClient
from service.proto.context_pb2 import ConfigRule, Constraint, EndPointId
from service.service.database.ConstraintModel import get_constraints, grpc_constraints_to_raw
from service.service.database.DatabaseDeviceTools import sync_device_from_context
from service.service.database.EndPointModel import EndPointModel, grpc_endpointids_to_raw
from .database.ConfigModel import ORM_ConfigActionEnum, get_config_rules, grpc_config_rules_to_raw
from .database.DeviceModel import DeviceModel, DriverModel
from .database.RelationModels import ServiceEndPointModel
from .database.ServiceModel import ServiceModel
from .service_handler_api._ServiceHandler import _ServiceHandler
from .service_handler_api.FilterFields import FilterFieldEnum
from .service_handler_api.ServiceHandlerFactory import ServiceHandlerFactory

# Module-level logger named after this module; used by get_service_handler_class below.
LOGGER = logging.getLogger(__name__)

def sync_devices_from_context(
    context_client : ContextClient, db_service : ServiceModel, service_endpoint_ids : List[EndPointId]
    ) -> Dict[str, DeviceModel]:
    """Re-synchronize from Context every device involved in a service and return them.

    The required devices are the union of:
      * the devices owning the endpoints already related to ``db_service`` in the local database, and
      * the devices referenced by ``service_endpoint_ids``.
    Each device is refreshed with ``sync_device_from_context`` and then loaded from the local database.

    Args:
        context_client: Context client forwarded to ``sync_device_from_context``.
        db_service: service whose stored endpoints contribute device UUIDs.
        service_endpoint_ids: endpoint ids from the request; their device UUIDs are added as well.

    Returns:
        Mapping of device UUID to its ``DeviceModel`` for every required device.

    Raises:
        NotFoundException: if any required device is still absent from the database after syncing.
    """
    database = db_service.database

    # Devices already attached to the service in the local database...
    device_uuids = {
        DeviceModel(database, db_endpoint.device_fk).device_uuid
        for db_endpoint in get_related_objects(db_service, ServiceEndPointModel, 'endpoint_fk')
    }
    # ...plus the devices referenced by the incoming request.
    device_uuids.update(endpoint_id.device_id.device_uuid.uuid for endpoint_id in service_endpoint_ids)

    found_devices : Dict[str, DeviceModel] = {}
    missing_uuids = set()
    for device_uuid in device_uuids:
        sync_device_from_context(device_uuid, context_client, database)
        db_device = get_object(database, DeviceModel, device_uuid, raise_if_not_found=False)
        if db_device is not None:
            found_devices[device_uuid] = db_device
            continue
        missing_uuids.add(device_uuid)

    if missing_uuids:
        extra_details = ['Devices({:s}) cannot be retrieved from Context'.format(str(missing_uuids))]
        raise NotFoundException('Device', '...', extra_details=extra_details)

    return found_devices

def classify_config_rules(
    db_service : ServiceModel, service_config_rules : List[ConfigRule],
    resources_to_set: List[Tuple[str, Any]], resources_to_delete : List[Tuple[str, Any]]):
    """Compare requested config rules against the service's 'running' rules and classify them.

    Results are appended in place to the two output lists; nothing is returned.

    Args:
        db_service: service whose 'running' config rules are read via ``get_config_rules``.
        service_config_rules: requested rules, converted with ``grpc_config_rules_to_raw``
            into (action, key, value) triples.
        resources_to_set: receives (key, value) for every SET rule whose key is not running yet
            or whose running value differs.
        resources_to_delete: receives (key, value) for every DELETE rule whose key is running.
    """
    # Index running rules by key. Stored rules are read positionally as [1] = key, [2] = value
    # (presumably (action, key, value) tuples -- confirm against get_config_rules).
    running_values = {}
    for stored_rule in get_config_rules(db_service.database, db_service.pk, 'running'):
        running_values[stored_rule[1]] = stored_rule[2]

    for action, key, value in grpc_config_rules_to_raw(service_config_rules):
        if action == ORM_ConfigActionEnum.SET:
            needs_update = (key not in running_values) or (running_values[key] != value)
            if needs_update:
                resources_to_set.append((key, value))
        elif action == ORM_ConfigActionEnum.DELETE and key in running_values:
            # Only request deletion of keys that are actually running.
            resources_to_delete.append((key, value))

def classify_constraints(
    db_service : ServiceModel, service_constraints : List[Constraint],
    constraints_to_set: List[Tuple[str, str]], constraints_to_delete : List[Tuple[str, str]]):
    """Diff requested constraints against the service's 'running' constraints.

    The request is treated as the complete desired set of constraints. Results are appended
    in place to the two output lists; nothing is returned.

    Args:
        db_service: service whose 'running' constraints are read via ``get_constraints``.
        service_constraints: requested constraints, converted with ``grpc_constraints_to_raw``
            into (constraint_type, constraint_value) pairs.
        constraints_to_set: receives every requested constraint whose type is not running yet
            or whose running value differs.
        constraints_to_delete: receives a (constraint_type, constraint_value) pair, with the
            running value, for every running constraint whose type the request does not mention.
    """
    # Index running constraints by type: [0] = type, [1] = value.
    context_constraints = get_constraints(db_service.database, db_service.pk, 'running')
    context_constraints = {constraint[0]: constraint[1] for constraint in context_constraints}

    request_constraints = grpc_constraints_to_raw(service_constraints)

    for constraint in request_constraints:
        constraint_type, constraint_value = constraint
        if (constraint_type not in context_constraints) or \
           (context_constraints[constraint_type] != constraint_value):
            constraints_to_set.append(constraint)
        # Whatever remains in context_constraints after this loop was not requested.
        context_constraints.pop(constraint_type, None)

    # Iterating the dict directly would yield only the type keys. Emit (type, value) pairs so
    # the list matches its declared List[Tuple[str, str]] type and the shape of constraints_to_set.
    constraints_to_delete.extend(context_constraints.items())

def get_service_endpointids(db_service : ServiceModel) -> List[Tuple[str, str, Optional[str]]]:
    """Return the endpoint ids currently related to ``db_service`` as flat tuples.

    Each tuple is (device_uuid, endpoint_uuid, topology_uuid). It is built from the dict
    returned by ``EndPointModel.dump_id()``; topology_uuid is None when the dump carries no
    topology information.
    """
    db_endpoints : List[EndPointModel] = get_related_objects(db_service, ServiceEndPointModel, 'endpoint_fk')
    # Dump every endpoint first, then project each dump into a tuple.
    endpoint_dumps = [db_endpoint.dump_id() for db_endpoint in db_endpoints]

    flat_ids = []
    for dump in endpoint_dumps:
        device_uuid = dump['device_id']['device_uuid']['uuid']
        endpoint_uuid = dump['endpoint_uuid']['uuid']
        topology_uuid = dump.get('topology_id', {}).get('topology_uuid', {}).get('uuid', None)
        flat_ids.append((device_uuid, endpoint_uuid, topology_uuid))
    return flat_ids

def classify_endpointids(
    db_service : ServiceModel, service_endpoint_ids : List[EndPointId],
    endpointids_to_set: List[Tuple[str, str, Optional[str]]],
    endpointids_to_delete : List[Tuple[str, str, Optional[str]]]):
    """Diff requested endpoint ids against those currently related to the service.

    The request is treated as the complete desired set of endpoints. Results are appended
    in place to the two output lists; nothing is returned.

    Args:
        db_service: service whose current endpoints come from ``get_service_endpointids``.
        service_endpoint_ids: requested endpoint ids, converted with ``grpc_endpointids_to_raw``.
        endpointids_to_set: receives each requested endpoint id not already related to the service.
        endpointids_to_delete: receives each current endpoint id absent from the request.
    """
    # Current endpoints not yet matched by the request; shrinks as matches are found.
    unmatched = set(get_service_endpointids(db_service))

    for requested_id in grpc_endpointids_to_raw(service_endpoint_ids):
        if requested_id in unmatched:
            unmatched.discard(requested_id)
        else:
            endpointids_to_set.append(requested_id)

    # Anything still unmatched is no longer wanted.
    endpointids_to_delete.extend(unmatched)

def get_service_handler_class(
    service_handler_factory : ServiceHandlerFactory, db_service : ServiceModel, db_devices : Dict[str, DeviceModel]
    ) -> Optional[Type[_ServiceHandler]]:
    """Select the service-handler class for a service, based on its type and its devices' drivers.

    Builds the filter fields passed to ``service_handler_factory.get_service_handler_class``:
      * SERVICE_TYPE: the service's type value;
      * DEVICE_DRIVER: the set of driver values common to all devices in ``db_devices``
        (the intersection of each device's drivers).

    Args:
        service_handler_factory: factory that picks a handler class from the filter fields.
        db_service: service being handled.
        db_devices: devices involved in the service, keyed by device UUID (keys are unused here).

    Returns:
        The handler class chosen by the factory (a class, not an instance), or None if the
        factory returns None.
    """
    str_service_key = db_service.pk
    database = db_service.database

    # Assume all devices involved in the service must support at least one driver in common.
    # NOTE(review): if db_devices is empty this stays None and is passed to the factory as-is.
    device_drivers = None
    for db_device in db_devices.values():
        db_driver_pks = db_device.references(DriverModel)
        db_driver_names = [DriverModel(database, pk).driver.value for pk,_ in db_driver_pks]
        if device_drivers is None:
            device_drivers = set(db_driver_names)
        else:
            device_drivers.intersection_update(db_driver_names)

    filter_fields = {
        FilterFieldEnum.SERVICE_TYPE.value  : db_service.service_type.value,    # must be supported
        FilterFieldEnum.DEVICE_DRIVER.value : device_drivers,                   # at least one must be supported
    }

    msg = 'Selecting service handler for service({:s}) with filter_fields({:s})...'
    LOGGER.info(msg.format(str(str_service_key), str(filter_fields)))
    service_handler_class = service_handler_factory.get_service_handler_class(**filter_fields)

    # The declared return type is Optional: report a missing handler explicitly instead of
    # crashing with AttributeError on None.__name__.
    if service_handler_class is None:
        msg = 'No ServiceHandler selected for service({:s}) with filter_fields({:s})'
        LOGGER.warning(msg.format(str(str_service_key), str(filter_fields)))
        return None

    msg = 'ServiceHandler({:s}) selected for service({:s}) with filter_fields({:s})...'
    LOGGER.info(msg.format(str(service_handler_class.__name__), str(str_service_key), str(filter_fields)))
    return service_handler_class
