Skip to content
Snippets Groups Projects
DriverFactory.py 4.9 KiB
Newer Older
Lluis Gifre Renom's avatar
Lluis Gifre Renom committed
# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging, operator
from enum import Enum
from typing import Any, Dict, Iterable, List, Set, Tuple
from ._Driver import _Driver
from .Exceptions import (
    UnsatisfiedFilterException, UnsupportedDriverClassException, UnsupportedFilterFieldException,
    UnsupportedFilterFieldValueException)
from .FilterFields import FILTER_FIELD_ALLOWED_VALUES, FilterFieldEnum

LOGGER = logging.getLogger(__name__)

class DriverFactory:
    def __init__(self, drivers : List[Tuple[type, List[Dict[FilterFieldEnum, Any]]]]) -> None:
        self.__indices : Dict[str, Dict[str, Set[_Driver]]] = {} # Dict{field_name => Dict{field_value => Set{Driver}}}

        for driver_class,filter_field_sets in drivers:
            for filter_fields in filter_field_sets:
                filter_fields = {k.value:v for k,v in filter_fields.items()}
                self.register_driver_class(driver_class, **filter_fields)

    def register_driver_class(self, driver_class, **filter_fields):
        if not issubclass(driver_class, _Driver): raise UnsupportedDriverClassException(str(driver_class))

        driver_name = driver_class.__name__
        supported_filter_fields = set(FILTER_FIELD_ALLOWED_VALUES.keys())
        unsupported_filter_fields = set(filter_fields.keys()).difference(supported_filter_fields)
        if len(unsupported_filter_fields) > 0:
            raise UnsupportedFilterFieldException(unsupported_filter_fields, driver_class_name=driver_name)
        for field_name, field_values in filter_fields.items():
            field_indice = self.__indices.setdefault(field_name, dict())
            field_enum_values = FILTER_FIELD_ALLOWED_VALUES.get(field_name)
            if not isinstance(field_values, Iterable) or isinstance(field_values, str):
                field_values = [field_values]
            for field_value in field_values:
                if isinstance(field_value, Enum): field_value = field_value.value
                if field_enum_values is not None and field_value not in field_enum_values:
                    raise UnsupportedFilterFieldValueException(
                        field_name, field_value, field_enum_values, driver_class_name=driver_name)
                field_indice_drivers = field_indice.setdefault(field_value, set())
                field_indice_drivers.add(driver_class)
    def get_driver_class(self, **filter_fields) -> _Driver:
        supported_filter_fields = set(FILTER_FIELD_ALLOWED_VALUES.keys())
        unsupported_filter_fields = set(filter_fields.keys()).difference(supported_filter_fields)
        if len(unsupported_filter_fields) > 0: raise UnsupportedFilterFieldException(unsupported_filter_fields)
        candidate_driver_classes : Dict[_Driver, int] = None # number of filter hits per driver
        for field_name, field_values in filter_fields.items():
            field_indice = self.__indices.get(field_name)
            if field_indice is None: continue
            field_enum_values = FILTER_FIELD_ALLOWED_VALUES.get(field_name)
            if not isinstance(field_values, Iterable) or isinstance(field_values, str):
                field_values = [field_values]

            field_candidate_driver_classes = set()
            for field_value in field_values:
                if field_enum_values is not None and field_value not in field_enum_values:
                    raise UnsupportedFilterFieldValueException(field_name, field_value, field_enum_values)
                field_indice_drivers = field_indice.get(field_value)
                if field_indice_drivers is None: continue
                field_candidate_driver_classes = field_candidate_driver_classes.union(field_indice_drivers)

            if candidate_driver_classes is None:
Lluis Gifre Renom's avatar
Lluis Gifre Renom committed
                if len(field_candidate_driver_classes) == 0: continue
                candidate_driver_classes = {k:1 for k in field_candidate_driver_classes}
            else:
                for candidate_driver_class in candidate_driver_classes:
                    if candidate_driver_class not in field_candidate_driver_classes: continue
                    candidate_driver_classes[candidate_driver_class] += 1

        if len(candidate_driver_classes) == 0: raise UnsatisfiedFilterException(filter_fields)
        candidate_driver_classes = sorted(candidate_driver_classes.items(), key=operator.itemgetter(1), reverse=True)
        return candidate_driver_classes[0][0]