# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy, logging
from typing import Dict, Optional, Tuple
from common.Constants import DEFAULT_TOPOLOGY_UUID, DOMAINS_TOPOLOGY_UUID
from common.DeviceTypes import DeviceTypeEnum
from common.proto.context_pb2 import (
    ContextId, Device, DeviceDriverEnum, DeviceId, DeviceOperationalStatusEnum, EndPoint)
from common.tools.object_factory.Device import json_device, json_device_id
from context.client.ContextClient import ContextClient
from interdomain.service.topology_abstractor.Tools import (
    add_device_to_topology, create_missing_topologies, find_own_domain_uuid, get_devices_in_topology,
    get_existing_device_uuids)

LOGGER = logging.getLogger(__name__)

class AbstractDevice:
    """Maintain the abstract device that represents the own domain.

    The abstract device is a Device whose UUID equals the own domain UUID. Its endpoints mirror the
    endpoints of real devices whose endpoint type ends with '/border'. Each mirrored ("interdomain")
    endpoint is named '<endpoint_uuid>@<device_uuid>', and two dictionaries keep the translation
    between real (device_uuid, endpoint_uuid) pairs and interdomain endpoint UUIDs in both directions.

    NOTE(review): no method visible in this class assigns the own context id / own domain uuid; they
    stay None unless populated by code not shown here -- confirm how they are meant to be set.
    """

    def __init__(self):
        self.__context_client = ContextClient()

        self.__own_context_id : Optional[ContextId] = None
        self.__own_domain_uuid : Optional[str] = None # uuid of own_context_id

        self.__own_abstract_device : Optional[Device] = None
        self.__own_abstract_device_id : Optional[DeviceId] = None

        # Dict[device_uuid, Dict[endpoint_uuid, Tuple[interdomain_endpoint_uuid, abstract EndPoint]]]
        self.__device_endpoint_to_abstract : Dict[str, Dict[str, Tuple[str, EndPoint]]] = dict()

        # Dict[interdomain_endpoint_uuid, Tuple[device_uuid, endpoint_uuid]]
        self.__abstract_to_device_endpoint : Dict[str, Tuple[str, str]] = dict()

    @property
    def own_context_id(self): return self.__own_context_id

    @property
    def own_domain_uuid(self): return self.__own_domain_uuid

    # the abstract device is created/loaded using the own domain uuid as its device uuid
    @property
    def own_abstract_device_uuid(self): return self.__own_domain_uuid

    @property
    def own_abstract_device_id(self): return self.__own_abstract_device_id

    @property
    def own_abstract_device(self): return self.__own_abstract_device

    def _load_existing_abstract_device(self) -> None:
        """Retrieve the abstract device through the context client and rebuild both mappings.

        Raises:
            ValueError: if an endpoint UUID of the retrieved device does not contain '@'.
        """
        self.__device_endpoint_to_abstract = dict()
        self.__abstract_to_device_endpoint = dict()

        self.__own_abstract_device_id = DeviceId(**json_device_id(self.__own_domain_uuid))
        self.__own_abstract_device = self.__context_client.GetDevice(self.__own_abstract_device_id)

        # for each endpoint in own_abstract_device, populate internal data structures and mappings
        for interdomain_endpoint in self.__own_abstract_device.device_endpoints:
            interdomain_endpoint_uuid : str = interdomain_endpoint.endpoint_id.endpoint_uuid.uuid
            # interdomain endpoint uuids are composed as '<endpoint_uuid>@<device_uuid>'
            endpoint_uuid,device_uuid = interdomain_endpoint_uuid.split('@', maxsplit=1)

            interdomain_endpoint_tuple = (interdomain_endpoint_uuid, interdomain_endpoint)
            self.__device_endpoint_to_abstract\
                .setdefault(device_uuid, {}).setdefault(endpoint_uuid, interdomain_endpoint_tuple)
            self.__abstract_to_device_endpoint\
                .setdefault(interdomain_endpoint_uuid, (device_uuid, endpoint_uuid))

    def _create_empty_abstract_device(self) -> None:
        """Create an abstract device without endpoints, store it, and add it to the domains topology."""
        own_abstract_device_uuid = self.__own_domain_uuid

        own_abstract_device = Device(**json_device(
            own_abstract_device_uuid, DeviceTypeEnum.NETWORK.value,
            DeviceOperationalStatusEnum.DEVICEOPERATIONALSTATUS_ENABLED,
            endpoints=[], config_rules=[], drivers=[DeviceDriverEnum.DEVICEDRIVER_UNDEFINED]
        ))
        self.__context_client.SetDevice(own_abstract_device)
        self.__own_abstract_device = own_abstract_device
        self.__own_abstract_device_id = self.__own_abstract_device.device_id

        # Add own abstract device to topologies ["domains"]
        topology_uuids = [DOMAINS_TOPOLOGY_UUID]
        for topology_uuid in topology_uuids:
            add_device_to_topology(
                self.__context_client, self.__own_context_id, topology_uuid, own_abstract_device_uuid)

    def _discover_or_create_abstract_device(self) -> bool:
        """Load the abstract device if a device with the own domain UUID exists; otherwise create it.

        Returns:
            True if a new abstract device was created; False if it was already held by this object
            or was loaded from the existing devices.
        """
        # already discovered
        if self.__own_abstract_device is not None: return False

        # discover from existing devices; should have name of the own domain context
        existing_device_uuids = get_existing_device_uuids(self.__context_client)
        create_abstract_device = self.__own_domain_uuid not in existing_device_uuids
        if create_abstract_device:
            self._create_empty_abstract_device()
        else:
            self._load_existing_abstract_device()
        return create_abstract_device

    def _update_endpoint_type(self, device_uuid : str, endpoint_uuid : str, endpoint_type : str) -> bool:
        """Set the type of the interdomain endpoint mapped to (device_uuid, endpoint_uuid).

        Returns:
            True if the stored endpoint type was changed; False if it already had that type or if
            no interdomain endpoint is mapped to the given device/endpoint pair.
        """
        device_endpoint_to_abstract = self.__device_endpoint_to_abstract.get(device_uuid, {})
        interdomain_endpoint_tuple = device_endpoint_to_abstract.get(endpoint_uuid)
        # unknown device/endpoint pair: nothing to update (previously failed unpacking None)
        if interdomain_endpoint_tuple is None: return False

        _, interdomain_endpoint = interdomain_endpoint_tuple
        if endpoint_type == interdomain_endpoint.endpoint_type: return False
        interdomain_endpoint.endpoint_type = endpoint_type
        return True

    def _add_interdomain_endpoint(
        self, device_uuid : str, endpoint_uuid : str, endpoint_type : str, interdomain_endpoint_uuid : str
    ) -> EndPoint:
        """Append a new endpoint to the abstract device and register it in both mappings.

        Returns:
            The endpoint that was added to the abstract device.
        """
        interdomain_endpoint = self.__own_abstract_device.device_endpoints.add()
        interdomain_endpoint.endpoint_id.device_id.CopyFrom(self.__own_abstract_device_id)
        interdomain_endpoint.endpoint_id.endpoint_uuid.uuid = interdomain_endpoint_uuid
        interdomain_endpoint.endpoint_type = endpoint_type

        interdomain_endpoint_tuple = (interdomain_endpoint_uuid, interdomain_endpoint)
        self.__device_endpoint_to_abstract\
            .setdefault(device_uuid, {}).setdefault(endpoint_uuid, interdomain_endpoint_tuple)
        self.__abstract_to_device_endpoint\
            .setdefault(interdomain_endpoint_uuid, (device_uuid, endpoint_uuid))

        return interdomain_endpoint

    def _remove_interdomain_endpoint(
        self, device_uuid : str, endpoint_uuid : str, interdomain_endpoint_tuple : Tuple[str, EndPoint]
    ) -> None:
        """Drop an interdomain endpoint from both mappings and from the abstract device."""
        interdomain_endpoint_uuid, interdomain_endpoint = interdomain_endpoint_tuple
        self.__abstract_to_device_endpoint.pop(interdomain_endpoint_uuid, None)
        device_endpoint_to_abstract = self.__device_endpoint_to_abstract.get(device_uuid, {})
        device_endpoint_to_abstract.pop(endpoint_uuid, None)
        self.__own_abstract_device.device_endpoints.remove(interdomain_endpoint)

    def update_abstract_device_endpoints(self, device : Device) -> bool:
        """Synchronize the abstract device endpoints with the border endpoints of a device.

        Interdomain endpoints whose source endpoint is no longer a border endpoint of the device are
        removed, new border endpoints are added, and the type of endpoints already present is
        refreshed. Only the in-memory abstract device is modified by this method.

        Args:
            device: device whose endpoints with a type ending in '/border' must be mirrored.

        Returns:
            True if the abstract device was modified; False otherwise.
        """
        device_uuid = device.device_id.device_uuid.uuid
        device_border_endpoint_uuids = {
            endpoint.endpoint_id.endpoint_uuid.uuid : endpoint.endpoint_type
            for endpoint in device.device_endpoints
            if str(endpoint.endpoint_type).endswith('/border')
        }

        updated = False

        # for each border endpoint in own_abstract_device that is not in device; remove from own_abstract_device
        # iterate over a snapshot of the items since removals mutate the underlying dictionary
        device_endpoint_to_abstract = self.__device_endpoint_to_abstract.get(device_uuid, {})
        for endpoint_uuid, interdomain_endpoint_tuple in list(device_endpoint_to_abstract.items()):
            if endpoint_uuid in device_border_endpoint_uuids: continue
            # remove interdomain endpoint that is not in device
            self._remove_interdomain_endpoint(device_uuid, endpoint_uuid, interdomain_endpoint_tuple)
            updated = True

        # for each border endpoint in device that is not in own_abstract_device; add to own_abstract_device
        for endpoint_uuid,endpoint_type in device_border_endpoint_uuids.items():
            # compose interdomain endpoint uuid
            interdomain_endpoint_uuid = '{:s}@{:s}'.format(endpoint_uuid, device_uuid)

            # if already added; just check endpoint type is not modified
            if interdomain_endpoint_uuid in self.__abstract_to_device_endpoint:
                # always perform the update; 'updated or update(...)' would skip the call (and thus
                # the type change) whenever a previous modification already set updated to True
                if self._update_endpoint_type(device_uuid, endpoint_uuid, endpoint_type):
                    updated = True
                continue

            # otherwise, add it to the abstract device
            self._add_interdomain_endpoint(device_uuid, endpoint_uuid, endpoint_type, interdomain_endpoint_uuid)
            updated = True

        return updated

    def initialize(self) -> Optional[bool]:
        """Discover or create the abstract device and mirror the devices of the default topology.

        Returns:
            False if the abstract device was already available (nothing is done); True otherwise.
        """
        if self.__own_abstract_device is not None: return False

        # Discover or Create device representing abstract local domain
        self._discover_or_create_abstract_device()

        devices_in_admin_topology = get_devices_in_topology(
            self.__context_client, self.__own_context_id, DEFAULT_TOPOLOGY_UUID)
        for device in devices_in_admin_topology:
            self.update_abstract_device_endpoints(device)

        return True
