from typing import Dict, Set
from device.service.driver_api.QueryFields import QUERY_FIELDS
from device.service.driver_api._Driver import _Driver
from device.service.driver_api.Exceptions import MultipleResultsForQueryException, UnsatisfiedQueryException, \
    UnsupportedDriverClassException, UnsupportedQueryFieldException, UnsupportedQueryFieldValueException

class DriverFactory:
    """Registry resolving query fields to a single registered driver class.

    Driver classes are registered with keyword query fields whose names must be keys of
    QUERY_FIELDS (and, when QUERY_FIELDS gives allowed values for a field, whose values
    must be among them). get_driver_class() intersects the drivers indexed under each
    queried field/value pair and returns the unique match.
    """

    def __init__(self) -> None:
        # Dict{field_name => Dict{field_value => Set{driver class}}}
        self.__indices : Dict[str, Dict[str, Set[_Driver]]] = {}

    @staticmethod
    def _validate_query_fields(query_fields, driver_class_name=None) -> None:
        """Raise if a field name is not in QUERY_FIELDS or a value is outside its allowed values.

        Args:
            query_fields: mapping of field name to field value to check.
            driver_class_name: forwarded to the exceptions as `driver_class_name` when given;
                omitted entirely when None (the form used for queries).

        Raises:
            UnsupportedQueryFieldException: some field names are not keys of QUERY_FIELDS.
            UnsupportedQueryFieldValueException: a value is not among the field's allowed values.
        """
        extra_kwargs = {} if driver_class_name is None else {'driver_class_name': driver_class_name}

        unsupported_query_fields = set(query_fields).difference(QUERY_FIELDS)
        if len(unsupported_query_fields) > 0:
            raise UnsupportedQueryFieldException(unsupported_query_fields, **extra_kwargs)

        for field_name, field_value in query_fields.items():
            field_enum_values = QUERY_FIELDS.get(field_name)
            # None means the field accepts any value.
            if field_enum_values is not None and field_value not in field_enum_values:
                raise UnsupportedQueryFieldValueException(
                    field_name, field_value, field_enum_values, **extra_kwargs)

    def register_driver_class(self, driver_class, **query_fields):
        """Index `driver_class` under every given field/value pair.

        Raises:
            UnsupportedDriverClassException: driver_class is not a subclass of _Driver.
            UnsupportedQueryFieldException / UnsupportedQueryFieldValueException: invalid
                field name or value; in that case nothing is registered.
        """
        if not issubclass(driver_class, _Driver): raise UnsupportedDriverClassException(str(driver_class))

        # Validate everything first so a rejected registration leaves no partial index entries.
        self._validate_query_fields(query_fields, driver_class_name=driver_class.__name__)

        for field_name, field_value in query_fields.items():
            field_indice = self.__indices.setdefault(field_name, dict())
            # Key by the field VALUE so drivers registered with different values stay distinct.
            field_indice.setdefault(field_value, set()).add(driver_class)

    def get_driver_class(self, **query_fields) -> _Driver:
        """Return the single registered driver class satisfying `query_fields`.

        Field/value pairs with no registered driver do not constrain the result. The returned
        object is the registered class itself, not an instance.

        Raises:
            UnsupportedQueryFieldException / UnsupportedQueryFieldValueException: invalid
                field name or value (checked for every field, indexed or not).
            UnsatisfiedQueryException: no driver matches, including when no queried
                field/value pair has any registered driver.
            MultipleResultsForQueryException: more than one driver matches.
        """
        self._validate_query_fields(query_fields)

        candidate_driver_classes = None
        for field_name, field_value in query_fields.items():
            field_indice_drivers = self.__indices.get(field_name, {}).get(field_value)
            if field_indice_drivers is None: continue

            if candidate_driver_classes is None:
                # Copy so later narrowing/pop() never mutates the index itself.
                candidate_driver_classes = set(field_indice_drivers)
            else:
                candidate_driver_classes = candidate_driver_classes.intersection(field_indice_drivers)

        # None (nothing constrained the search) is treated like an empty result.
        if not candidate_driver_classes: raise UnsatisfiedQueryException(query_fields)
        if len(candidate_driver_classes) > 1:
            # TODO: Consider choosing driver with more query fields being satisfied (i.e., the most restrictive one)
            raise MultipleResultsForQueryException(query_fields, {d.__name__ for d in candidate_driver_classes})
        return candidate_driver_classes.pop()
