from typing import Any, Dict, Iterable, List, Set, Tuple
from device.service.driver_api._Driver import _Driver
from device.service.driver_api.Exceptions import MultipleResultsForFilterException, UnsatisfiedFilterException, \
    UnsupportedDriverClassException, UnsupportedFilterFieldException, UnsupportedFilterFieldValueException
from device.service.driver_api.FilterFields import FILTER_FIELD_ALLOWED_VALUES, FilterFieldEnum

class DriverFactory:
    """Registry of driver classes indexed by filter-field values.

    Each driver class is registered under one or more (field_name, field_value) pairs. A lookup
    takes a set of filter fields and resolves the single driver class that matches all of them.
    """

    def __init__(self, drivers : List[Tuple[type, List[Dict[FilterFieldEnum, Any]]]]) -> None:
        """Build the indices from a list of (driver_class, [filter_fields, ...]) tuples.

        Every filter_fields dict is expanded as keyword arguments of register_driver_class(), so the
        exceptions documented there can be raised from the constructor as well.
        """
        # Dict{field_name => Dict{field_value => Set{Driver}}}
        self.__indices : Dict[str, Dict[str, Set[_Driver]]] = {}

        for driver_class,filter_field_sets in drivers:
            for filter_fields in filter_field_sets:
                self.register_driver_class(driver_class, **filter_fields)

    @staticmethod
    def _as_value_list(field_values : Any) -> List[Any]:
        """Normalize a filter value into a list: scalars and strings are wrapped, iterables are materialized."""
        if not isinstance(field_values, Iterable) or isinstance(field_values, str):
            return [field_values]
        # Materialize so that one-shot iterables (e.g., generators) can be traversed more than once.
        return list(field_values)

    def register_driver_class(self, driver_class, **filter_fields):
        """Register driver_class under every (field_name, field_value) pair given in filter_fields.

        Each keyword value may be a single value or an iterable of values (a string counts as a
        single value). Nothing is added to the indices unless the whole call validates.

        Raises:
            UnsupportedDriverClassException: driver_class is not a class derived from _Driver.
            UnsupportedFilterFieldException: a field name is not a key of FILTER_FIELD_ALLOWED_VALUES.
            UnsupportedFilterFieldValueException: a value is not among the allowed values of its
                field (fields whose allowed values are None accept any value).
        """
        # issubclass() raises TypeError for non-class arguments; report those as unsupported too.
        if not isinstance(driver_class, type) or not issubclass(driver_class, _Driver):
            raise UnsupportedDriverClassException(str(driver_class))

        driver_name = driver_class.__name__
        unsupported_filter_fields = set(filter_fields).difference(FILTER_FIELD_ALLOWED_VALUES.keys())
        if len(unsupported_filter_fields) > 0:
            raise UnsupportedFilterFieldException(unsupported_filter_fields, driver_class_name=driver_name)

        # First pass: validate everything, so a rejected registration leaves the indices untouched.
        validated_fields : List[Tuple[str, List[Any]]] = []
        for field_name, field_values in filter_fields.items():
            field_enum_values = FILTER_FIELD_ALLOWED_VALUES.get(field_name)
            field_values = self._as_value_list(field_values)
            for field_value in field_values:
                if field_enum_values is not None and field_value not in field_enum_values:
                    raise UnsupportedFilterFieldValueException(
                        field_name, field_value, field_enum_values, driver_class_name=driver_name)
            validated_fields.append((field_name, field_values))

        # Second pass: index the driver under each *value* of each field.
        for field_name, field_values in validated_fields:
            field_indice = self.__indices.setdefault(field_name, dict())
            for field_value in field_values:
                field_indice.setdefault(field_value, set()).add(driver_class)

    def get_driver_class(self, **filter_fields) -> _Driver:
        """Return the only registered driver class that satisfies all the given filter fields.

        For each field, a driver matches when it was registered under at least one of the requested
        values (union); the per-field matches are then intersected across fields.
        NOTE: a filter field under which no driver has ever been registered is skipped, i.e., it
        does not restrict the result and its values are not validated.

        Raises:
            UnsupportedFilterFieldException: a field name is not a key of FILTER_FIELD_ALLOWED_VALUES.
            UnsupportedFilterFieldValueException: a value is not among the allowed values of its field.
            UnsatisfiedFilterException: no driver matches (including when no field could be applied).
            MultipleResultsForFilterException: more than one driver matches.
        """
        unsupported_filter_fields = set(filter_fields).difference(FILTER_FIELD_ALLOWED_VALUES.keys())
        if len(unsupported_filter_fields) > 0: raise UnsupportedFilterFieldException(unsupported_filter_fields)

        candidate_driver_classes = None
        for field_name, field_values in filter_fields.items():
            field_indice = self.__indices.get(field_name)
            if field_indice is None: continue
            field_enum_values = FILTER_FIELD_ALLOWED_VALUES.get(field_name)

            # Always a fresh set: sets stored in the indices must never be handed out or mutated here.
            field_candidate_driver_classes = set()
            for field_value in self._as_value_list(field_values):
                if field_enum_values is not None and field_value not in field_enum_values:
                    raise UnsupportedFilterFieldValueException(field_name, field_value, field_enum_values)
                field_candidate_driver_classes.update(field_indice.get(field_value, ()))

            if candidate_driver_classes is None:
                candidate_driver_classes = field_candidate_driver_classes
            else:
                candidate_driver_classes = candidate_driver_classes.intersection(field_candidate_driver_classes)

        # None means no filter field could be applied at all; treat it the same as "no match".
        if not candidate_driver_classes: raise UnsatisfiedFilterException(filter_fields)
        if len(candidate_driver_classes) >  1:
            # TODO: Consider choosing driver with more query fields being satisfied (i.e., the most restrictive one)
            raise MultipleResultsForFilterException(filter_fields, {d.__name__ for d in candidate_driver_classes})
        return next(iter(candidate_driver_classes))
