Commit 82ba0a55 authored by Lluis Gifre Renom's avatar Lluis Gifre Renom
Browse files

Service component:

- fix wrong import of TaskScheduler in _ServiceHandlerAPI interface and related files
parent 87300b7b
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ from common.rpc_method_wrapper.ServiceExceptions import AlreadyExistsException,
from common.tools.grpc.Tools import grpc_message_to_json, grpc_message_to_json_string
from context.client.ContextClient import ContextClient
from pathcomp.frontend.client.PathCompClient import PathCompClient
from service.service.tools.ContextGetters import get_service
from .tools.ContextGetters import get_service
from .service_handler_api.ServiceHandlerFactory import ServiceHandlerFactory
from .task_scheduler.TaskScheduler import TasksScheduler

+9 −6
Original line number Diff line number Diff line
@@ -14,21 +14,23 @@

import logging, operator
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
from common.proto.context_pb2 import Device, Service
from common.tools.grpc.Tools import grpc_message_to_json_string
from service.service.service_handler_api._ServiceHandler import _ServiceHandler
from .Exceptions import (
    UnsatisfiedFilterException, UnsupportedServiceHandlerClassException, UnsupportedFilterFieldException,
    UnsupportedFilterFieldValueException)
from .FilterFields import FILTER_FIELD_ALLOWED_VALUES, FilterFieldEnum

if TYPE_CHECKING:
    from service.service.service_handler_api._ServiceHandler import _ServiceHandler

LOGGER = logging.getLogger(__name__)

class ServiceHandlerFactory:
    def __init__(self, service_handlers : List[Tuple[type, List[Dict[FilterFieldEnum, Any]]]]) -> None:
        # Dict{field_name => Dict{field_value => Set{ServiceHandler}}}
        self.__indices : Dict[str, Dict[str, Set[_ServiceHandler]]] = {}
        self.__indices : Dict[str, Dict[str, Set['_ServiceHandler']]] = {}

        for service_handler_class,filter_field_sets in service_handlers:
            for filter_fields in filter_field_sets:
@@ -36,6 +38,7 @@ class ServiceHandlerFactory:
                self.register_service_handler_class(service_handler_class, **filter_fields)

    def register_service_handler_class(self, service_handler_class, **filter_fields):
        from service.service.service_handler_api._ServiceHandler import _ServiceHandler
        if not issubclass(service_handler_class, _ServiceHandler):
            raise UnsupportedServiceHandlerClassException(str(service_handler_class))

@@ -59,12 +62,12 @@ class ServiceHandlerFactory:
                field_indice_service_handlers = field_indice.setdefault(field_value, set())
                field_indice_service_handlers.add(service_handler_class)

    def get_service_handler_class(self, **filter_fields) -> _ServiceHandler:
    def get_service_handler_class(self, **filter_fields) -> '_ServiceHandler':
        supported_filter_fields = set(FILTER_FIELD_ALLOWED_VALUES.keys())
        unsupported_filter_fields = set(filter_fields.keys()).difference(supported_filter_fields)
        if len(unsupported_filter_fields) > 0: raise UnsupportedFilterFieldException(unsupported_filter_fields)

        candidate_service_handler_classes : Dict[_ServiceHandler, int] = None # num. filter hits per service_handler
        candidate_service_handler_classes : Dict['_ServiceHandler', int] = None # num. filter hits per service_handler
        for field_name, field_values in filter_fields.items():
            field_indice = self.__indices.get(field_name)
            if field_indice is None: continue
@@ -109,7 +112,7 @@ def get_common_device_drivers(drivers_per_device : List[Set[int]]) -> Set[int]:

def get_service_handler_class(
    service_handler_factory : ServiceHandlerFactory, service : Service, connection_devices : Dict[str, Device]
) -> Optional[_ServiceHandler]:
) -> Optional['_ServiceHandler']:

    str_service_key = grpc_message_to_json_string(service.service_id)

+1 −1
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@

from typing import Any, List, Optional, Tuple, Union
from common.proto.context_pb2 import Service
from service.task_scheduler.TaskExecutor import TaskExecutor
from service.service.task_scheduler.TaskExecutor import TaskExecutor

class _ServiceHandler:
    def __init__(self,
+5 −3
Original line number Diff line number Diff line
@@ -13,16 +13,18 @@
# limitations under the License.

from enum import Enum
from typing import Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from common.proto.context_pb2 import Connection, ConnectionId, Device, DeviceId, Service, ServiceId
from common.rpc_method_wrapper.ServiceExceptions import NotFoundException
from context.client.ContextClient import ContextClient
from device.client.DeviceClient import DeviceClient
from service.service.service_handler_api._ServiceHandler import _ServiceHandler
from service.service.service_handler_api.ServiceHandlerFactory import ServiceHandlerFactory, get_service_handler_class
from service.service.tools.ContextGetters import get_connection, get_device, get_service
from service.service.tools.ObjectKeys import get_connection_key, get_device_key, get_service_key

if TYPE_CHECKING:
    from service.service.service_handler_api._ServiceHandler import _ServiceHandler

CacheableObject = Union[Connection, Device, Service]

class CacheableObjectType(Enum):
@@ -136,7 +138,7 @@ class TaskExecutor:

    def get_service_handler(
        self, connection : Connection, service : Service, **service_handler_settings
    ) -> _ServiceHandler:
    ) -> '_ServiceHandler':
        connection_devices = self.get_devices_from_connection(connection)
        service_handler_class = get_service_handler_class(self._service_handler_factory, service, connection_devices)
        return service_handler_class(service, self, **service_handler_settings)
+5 −3
Original line number Diff line number Diff line
@@ -13,12 +13,11 @@
# limitations under the License.

import graphlib, logging, queue, time
from typing import Dict, Tuple
from typing import TYPE_CHECKING, Dict, Tuple
from common.proto.context_pb2 import Connection, ConnectionId, Service, ServiceId, ServiceStatusEnum
from common.proto.pathcomp_pb2 import PathCompReply
from common.tools.grpc.Tools import grpc_message_to_json_string
from context.client.ContextClient import ContextClient
from service.service.service_handler_api.ServiceHandlerFactory import ServiceHandlerFactory
from service.service.tools.ObjectKeys import get_connection_key, get_service_key
from .tasks._Task import _Task
from .tasks.Task_ConnectionConfigure import Task_ConnectionConfigure
@@ -27,10 +26,13 @@ from .tasks.Task_ServiceDelete import Task_ServiceDelete
from .tasks.Task_ServiceSetStatus import Task_ServiceSetStatus
from .TaskExecutor import CacheableObjectType, TaskExecutor

if TYPE_CHECKING:
    from service.service.service_handler_api.ServiceHandlerFactory import ServiceHandlerFactory

LOGGER = logging.getLogger(__name__)

class TasksScheduler:
    def __init__(self, service_handler_factory : ServiceHandlerFactory) -> None:
    def __init__(self, service_handler_factory : 'ServiceHandlerFactory') -> None:
        self._dag = graphlib.TopologicalSorter()
        self._executor = TaskExecutor(service_handler_factory)
        self._tasks : Dict[str, _Task] = dict()