# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from common.method_wrappers.ServiceExceptions import NotFoundException
from common.proto.context_pb2 import Connection, ConnectionId, Device, DeviceId, Service, ServiceId
from context.client.ContextClient import ContextClient
from device.client.DeviceClient import DeviceClient
from service.service.service_handler_api.ServiceHandlerFactory import ServiceHandlerFactory, get_service_handler_class
from service.service.tools.ContextGetters import get_connection, get_device, get_service
from service.service.tools.ObjectKeys import get_connection_key, get_device_key, get_service_key

if TYPE_CHECKING:
    from service.service.service_handler_api._ServiceHandler import _ServiceHandler

# Union of the gRPC message types that TaskExecutor stores in its object cache.
CacheableObject = Union[Connection, Device, Service]

class CacheableObjectType(Enum):
    """Kinds of gRPC objects held in the TaskExecutor object cache.

    The value of each member is used as the prefix of the cache key, so objects of
    different kinds that share the same key never collide.
    """
    CONNECTION = 'connection'
    DEVICE = 'device'
    SERVICE = 'service'

class TaskExecutor:
    """Runs service-handler tasks against the Context and Device components.

    Owns a ContextClient and a DeviceClient, plus an in-memory cache of gRPC objects
    (Connection, Device, Service) keyed by '<object type>:<object key>'.

    The get_* methods return a mutable copy that lives in the cache. Repeated calls
    with the same id return the same instance, so in-place edits made by a caller are
    visible to later get_* calls. The set_*/configure_* methods push the object to the
    remote component first and only then cache it. The delete_* methods remove the
    object remotely first and then evict it from the cache.
    """

    def __init__(self, service_handler_factory : ServiceHandlerFactory) -> None:
        self._service_handler_factory = service_handler_factory
        self._context_client = ContextClient()
        self._device_client = DeviceClient()
        # Object cache; keys are produced by _get_cache_key().
        self._grpc_objects_cache : Dict[str, CacheableObject] = dict()

    @property
    def service_handler_factory(self) -> ServiceHandlerFactory: return self._service_handler_factory

    # ----- Common methods ---------------------------------------------------------------------------------------------

    @staticmethod
    def _get_cache_key(object_type : CacheableObjectType, object_key : str) -> str:
        # Prefix with the object type so equal keys of different object kinds never collide.
        return '{:s}:{:s}'.format(object_type.value, object_key)

    def _load_grpc_object(self, object_type : CacheableObjectType, object_key : str) -> Optional[CacheableObject]:
        """Return the cached object of the given type/key, or None if it is not cached."""
        return self._grpc_objects_cache.get(self._get_cache_key(object_type, object_key))

    def _store_grpc_object(
        self, object_type : CacheableObjectType, object_key : str, grpc_object : CacheableObject
    ) -> None:
        """Cache `grpc_object` as-is (by reference, no copy), replacing any previous entry."""
        self._grpc_objects_cache[self._get_cache_key(object_type, object_key)] = grpc_object

    def _delete_grpc_object(self, object_type : CacheableObjectType, object_key : str) -> None:
        """Evict the object of the given type/key from the cache; a missing entry is ignored."""
        self._grpc_objects_cache.pop(self._get_cache_key(object_type, object_key), None)

    def _store_editable_grpc_object(
        self, object_type : CacheableObjectType, object_key : str, grpc_class, grpc_ro_object
    ) -> Any:
        """Cache and return a fresh `grpc_class` instance holding a copy of `grpc_ro_object`.

        The copy (via CopyFrom) decouples the cached, caller-editable object from the one
        returned by the Context getter. Presumably that object should be treated as
        read-only, as the parameter name suggests; confirm against ContextGetters.
        """
        grpc_rw_object = grpc_class()
        grpc_rw_object.CopyFrom(grpc_ro_object)
        self._store_grpc_object(object_type, object_key, grpc_rw_object)
        return grpc_rw_object

    # ----- Connection-related methods ---------------------------------------------------------------------------------

    def get_connection(self, connection_id : ConnectionId) -> Connection:
        """Return the (cached, editable) Connection; fetch it from Context on a cache miss.

        Raises:
            NotFoundException: if the connection is neither cached nor found in Context.
        """
        connection_key = get_connection_key(connection_id)
        connection = self._load_grpc_object(CacheableObjectType.CONNECTION, connection_key)
        if connection is None:
            connection = get_connection(self._context_client, connection_id)
            if connection is None: raise NotFoundException('Connection', connection_key)
            connection : Connection = self._store_editable_grpc_object(
                CacheableObjectType.CONNECTION, connection_key, Connection, connection)
        return connection

    def set_connection(self, connection : Connection) -> None:
        """Write `connection` to Context, then cache it (by reference)."""
        connection_key = get_connection_key(connection.connection_id)
        self._context_client.SetConnection(connection)
        self._store_grpc_object(CacheableObjectType.CONNECTION, connection_key, connection)

    def delete_connection(self, connection_id : ConnectionId) -> None:
        """Remove the connection from Context, then evict it from the cache."""
        connection_key = get_connection_key(connection_id)
        self._context_client.RemoveConnection(connection_id)
        self._delete_grpc_object(CacheableObjectType.CONNECTION, connection_key)

    # ----- Device-related methods -------------------------------------------------------------------------------------

    def get_device(self, device_id : DeviceId) -> Device:
        """Return the (cached, editable) Device; fetch it from Context on a cache miss.

        Raises:
            NotFoundException: if the device is neither cached nor found in Context.
        """
        device_key = get_device_key(device_id)
        device = self._load_grpc_object(CacheableObjectType.DEVICE, device_key)
        if device is None:
            device = get_device(self._context_client, device_id)
            if device is None: raise NotFoundException('Device', device_key)
            device : Device = self._store_editable_grpc_object(
                CacheableObjectType.DEVICE, device_key, Device, device)
        return device

    def configure_device(self, device : Device) -> None:
        """Send `device` to the Device component via ConfigureDevice, then cache it (by reference)."""
        device_key = get_device_key(device.device_id)
        self._device_client.ConfigureDevice(device)
        self._store_grpc_object(CacheableObjectType.DEVICE, device_key, device)

    def get_devices_from_connection(self, connection : Connection) -> Dict[str, Device]:
        """Map device UUID -> Device for every device traversed by the connection's path hops.

        A device appearing in several hops is stored once (later hops overwrite the
        same dict key with the same cached instance).

        Raises:
            NotFoundException: propagated from get_device() for an unknown device.
        """
        devices : Dict[str, Device] = dict()
        for endpoint_id in connection.path_hops_endpoint_ids:
            device_uuid = endpoint_id.device_id.device_uuid.uuid
            # get_device() never returns None: it raises NotFoundException instead.
            devices[device_uuid] = self.get_device(endpoint_id.device_id)
        return devices

    # ----- Service-related methods ------------------------------------------------------------------------------------

    def get_service(self, service_id : ServiceId) -> Service:
        """Return the (cached, editable) Service; fetch it from Context on a cache miss.

        Raises:
            NotFoundException: if the service is neither cached nor found in Context.
        """
        service_key = get_service_key(service_id)
        service = self._load_grpc_object(CacheableObjectType.SERVICE, service_key)
        if service is None:
            service = get_service(self._context_client, service_id)
            if service is None: raise NotFoundException('Service', service_key)
            service : Service = self._store_editable_grpc_object(
                CacheableObjectType.SERVICE, service_key, Service, service)
        return service

    def set_service(self, service : Service) -> None:
        """Write `service` to Context, then cache it (by reference)."""
        service_key = get_service_key(service.service_id)
        self._context_client.SetService(service)
        self._store_grpc_object(CacheableObjectType.SERVICE, service_key, service)

    def delete_service(self, service_id : ServiceId) -> None:
        """Remove the service from Context, then evict it from the cache."""
        service_key = get_service_key(service_id)
        self._context_client.RemoveService(service_id)
        self._delete_grpc_object(CacheableObjectType.SERVICE, service_key)

    # ----- Service Handler Factory ------------------------------------------------------------------------------------

    def get_service_handler(
        self, connection : Connection, service : Service, **service_handler_settings
    ) -> '_ServiceHandler':
        """Instantiate the service handler selected for `service` over the connection's devices.

        The handler class is chosen by get_service_handler_class() from the factory, the
        service and the devices on the connection path. It is constructed as
        `handler_class(service, self, **service_handler_settings)`.
        """
        connection_devices = self.get_devices_from_connection(connection)
        service_handler_class = get_service_handler_class(self._service_handler_factory, service, connection_devices)
        return service_handler_class(service, self, **service_handler_settings)
