# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging, networkx, threading
from typing import List, Optional, Union
from common.Constants import DEFAULT_CONTEXT_UUID, DEFAULT_TOPOLOGY_UUID, DOMAINS_TOPOLOGY_UUID
from common.proto.context_pb2 import ContextEvent, ContextId, DeviceEvent, DeviceId, ServiceId, SliceId, TopologyEvent
from common.tools.grpc.Tools import grpc_message_to_json_string
from common.tools.object_factory.Context import json_context_id
from common.tools.object_factory.Device import json_device_id
from context.client.ContextClient import ContextClient
from context.client.EventsCollector import EventsCollector
from dlt.connector.client.DltConnectorClient import DltConnectorClient
from interdomain.service.topology_abstractor.Tools import create_missing_topologies, get_uuids_of_devices_in_topology
from interdomain.service.topology_abstractor.Types import DltRecordIdTypes, EventTypes
from .AbstractDevice import AbstractDevice
from .OwnDomainFinder import OwnDomainFinder

LOGGER = logging.getLogger(__name__)

class TopologyAbstractor(threading.Thread):
    """Daemon thread that turns Context events into updates of the own-domain abstraction.

    The thread pulls events from an ``EventsCollector``; every event that is not filtered out by
    ``ignore_event`` is handed to ``update_abstraction`` and the record identifiers it returns are
    forwarded to the DLT connector through ``send_dlt_records``.
    """

    def __init__(self) -> None:
        super().__init__(daemon=True)
        # Set by stop(); polled by run() on every loop iteration.
        self.terminate = threading.Event()

        self.context_client = ContextClient()
        self.dlt_connector_client = DltConnectorClient()
        self.context_event_collector = EventsCollector(self.context_client)

        # Stays None until _initialize_context_and_topologies() can resolve the own domain UUID.
        self.own_context_id : Optional[ContextId] = None
        # NOTE(review): the attribute name is misspelled ("filder"); it is kept unchanged because
        # it is a public attribute that code outside this class may reference.
        self.own_domain_filder = OwnDomainFinder()
        self.abstract_topology = networkx.Graph()
        self.abstract_device = AbstractDevice()

    def stop(self):
        """Request termination; run() leaves its loop on the next iteration."""
        self.terminate.set()

    def run(self) -> None:
        """Connect the clients, process events until stop() is called, then release everything.

        Exceptions raised while processing an event still propagate (and end the thread), but the
        events collector and both clients are now always stopped/closed on the way out.
        """
        self.context_client.connect()
        self.dlt_connector_client.connect()
        self.context_event_collector.start()

        try:
            while not self.terminate.is_set():
                # Short timeout so the terminate flag is re-checked regularly.
                event = self.context_event_collector.get_event(timeout=0.1)
                if event is None: continue
                if self.ignore_event(event): continue
                # TODO: filter events resulting from abstraction computation
                # TODO: filter events resulting from updating remote abstractions
                LOGGER.info('Processing Event({:s})...'.format(grpc_message_to_json_string(event)))
                dlt_records = self.update_abstraction(event)
                self.send_dlt_records(dlt_records)
        finally:
            self.context_event_collector.stop()
            self.context_client.close()
            self.dlt_connector_client.close()

    def ignore_event(self, event : EventTypes) -> bool:
        """Tell whether ``event`` must be skipped by the processing loop.

        Nothing is ignored while the own context is still unknown. Afterwards:
        - a ContextEvent is ignored when it refers to the own context;
        - a TopologyEvent is ignored when it belongs to a different context, or when it refers to
          the "domains" topology of the own context;
        - any other event is processed.
        """
        if self.own_context_id is None: return False
        own_context_uuid = self.own_context_id.context_uuid.uuid

        if isinstance(event, ContextEvent):
            return event.context_id.context_uuid.uuid == own_context_uuid

        if isinstance(event, TopologyEvent):
            topology_id = event.topology_id
            if topology_id.context_id.context_uuid.uuid != own_context_uuid: return True
            if topology_id.topology_uuid.uuid == DOMAINS_TOPOLOGY_UUID: return True

        return False

    def send_dlt_records(self, dlt_records : Union[DltRecordIdTypes, List[DltRecordIdTypes]]) -> None:
        """Record each given identifier through the DLT connector.

        ``dlt_records`` may be an iterable of record identifiers or, as the signature allows, a
        single identifier; a non-iterable argument is treated as one record. Identifiers that are
        not a DeviceId, ServiceId or SliceId are logged as errors and skipped.
        """
        try:
            records = iter(dlt_records)
        except TypeError:
            # Not iterable: a single record identifier was passed.
            records = iter((dlt_records,))

        for dlt_record_id in records:
            if isinstance(dlt_record_id, DeviceId):
                self.dlt_connector_client.RecordDevice(dlt_record_id)
            elif isinstance(dlt_record_id, ServiceId):
                self.dlt_connector_client.RecordService(dlt_record_id)
            elif isinstance(dlt_record_id, SliceId):
                self.dlt_connector_client.RecordSlice(dlt_record_id)
            else:
                LOGGER.error('Unsupported Record({:s})'.format(str(dlt_record_id)))

    def _initialize_context_and_topologies(self) -> None:
        """Resolve the own ContextId (once) and make sure its base topologies exist.

        Does nothing when the own context is already known or when the own domain UUID cannot be
        determined yet.
        """
        if self.own_context_id is not None: return

        own_domain_uuid = self.own_domain_filder.own_domain_uuid
        if own_domain_uuid is None: return

        # The own context is identified by the own domain UUID.
        self.own_context_id = ContextId(**json_context_id(own_domain_uuid))

        # The "admin" context is expected to exist already; it is not created here.

        # Create the default ("admin") and "domains" topologies within the own context if missing.
        topology_uuids = [DEFAULT_TOPOLOGY_UUID, DOMAINS_TOPOLOGY_UUID]
        create_missing_topologies(self.context_client, self.own_context_id, topology_uuids)

    def update_abstraction(self, event : Optional[EventTypes]) -> List[DltRecordIdTypes]:
        """Update the own abstract device according to ``event``.

        Returns the identifiers of the records that changed and need to be sent to the DLT: the
        own abstract device identifier when the abstraction changed, otherwise an empty list.
        ``None`` is accepted as event and results in no change.
        """
        dlt_record_ids_with_changes = []
        changed = False

        # TODO: identify changes from event and update endpoints accordingly
        if event is None:
            # just initializing, do nothing
            pass

        elif isinstance(event, ContextEvent):
            context_uuid = event.context_id.context_uuid.uuid
            if (context_uuid != DEFAULT_CONTEXT_UUID) and (self.own_context_id is None):
                self._initialize_context_and_topologies()

                own_domain_uuid = self.own_domain_filder.own_domain_uuid
                if own_domain_uuid is None:
                    # Own domain not identified yet: nothing to abstract. Returning here also
                    # avoids using None as a graph node, which networkx does not allow.
                    return dlt_record_ids_with_changes

                if self.abstract_topology.has_node(own_domain_uuid):
                    abstract_device = self.abstract_topology.nodes[own_domain_uuid]['obj']
                else:
                    # Register the very instance that is published below through SetDevice, so the
                    # initialization done here applies to the device that is actually sent.
                    abstract_device = self.abstract_device
                    self.abstract_topology.add_node(own_domain_uuid, obj=abstract_device)

                # if already initialized, does nothing and returns False
                # if own context UUID cannot be identified, does nothing and returns None
                # if own context UUID can be identified, initializes the abstract device and returns True
                _changed = abstract_device.initialize()
                if _changed is None: return dlt_record_ids_with_changes
                changed = changed or _changed
            else:
                LOGGER.warning('Ignoring Event({:s})'.format(grpc_message_to_json_string(event)))

        elif isinstance(event, TopologyEvent):
            topology_id = event.topology_id
            topology_uuid = topology_id.topology_uuid.uuid
            context_uuid = topology_id.context_id.context_uuid.uuid
            # Only the default topology of the own domain feeds the abstraction.
            if (context_uuid == self.own_domain_filder.own_domain_uuid) and (topology_uuid == DEFAULT_TOPOLOGY_UUID):
                topology = self.context_client.GetTopology(topology_id)
                for device_id in topology.device_ids:
                    device = self.context_client.GetDevice(device_id)
                    _changed = self.abstract_device.update_abstract_device_endpoints(device)
                    changed = changed or _changed
            else:
                LOGGER.warning('Ignoring Event({:s})'.format(grpc_message_to_json_string(event)))

        elif isinstance(event, DeviceEvent):
            # NOTE(review): own_context_id may still be None at this point; how the helper behaves
            # in that case is not visible from this file -- confirm.
            admin_topology_device_uuids = get_uuids_of_devices_in_topology(
                self.context_client, self.own_context_id, DEFAULT_TOPOLOGY_UUID)
            device_uuid = event.device_id.device_uuid.uuid
            if device_uuid in admin_topology_device_uuids:
                device = self.context_client.GetDevice(event.device_id)
                _changed = self.abstract_device.update_abstract_device_endpoints(device)
                changed = changed or _changed
            else:
                LOGGER.warning('Ignoring Event({:s})'.format(grpc_message_to_json_string(event)))

        else:
            LOGGER.warning('Unsupported Event({:s})'.format(grpc_message_to_json_string(event)))

        if changed:
            # Persist the updated abstract device and report it for recording in the DLT.
            self.context_client.SetDevice(self.abstract_device.own_abstract_device)
            dlt_record_ids_with_changes.append(self.abstract_device.own_abstract_device_id)

        return dlt_record_ids_with_changes
