# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List, Set, Tuple
from common.Constants import DEFAULT_CONTEXT_UUID, INTERDOMAIN_TOPOLOGY_UUID
from common.DeviceTypes import DeviceTypeEnum
from common.proto.context_pb2 import ContextId, Device, Empty, EndPointId, ServiceTypeEnum, Slice
from common.proto.pathcomp_pb2 import PathCompRequest
from common.tools.context_queries.CheckType import device_type_is_network
from common.tools.context_queries.Device import get_devices_in_topology, get_uuids_of_devices_in_topology
from common.tools.grpc.Tools import grpc_message_to_json_string
from common.tools.object_factory.Context import json_context_id
from context.client.ContextClient import ContextClient
from pathcomp.frontend.client.PathCompClient import PathCompClient

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Identifier of the default ("admin") context, built once from DEFAULT_CONTEXT_UUID
# and reused by every context query in this module.
ADMIN_CONTEXT_ID = ContextId(**json_context_id(DEFAULT_CONTEXT_UUID))

# Device types that identify a datacenter. Devices of these types are left out
# of the list of traversed domains.
# NOTE(review): this set holds DeviceTypeEnum members, while it is tested against
# `device.device_type` values; membership only matches if both compare equal --
# confirm against the DeviceTypeEnum definition.
DATACENTER_DEVICE_TYPES = {DeviceTypeEnum.DATACENTER, DeviceTypeEnum.EMULATED_DATACENTER}

def get_local_device_uuids(context_client : ContextClient) -> Set[str]:
    """Return the UUIDs of the devices attached to the local topologies.

    Every topology listed for the admin context is taken into account, except
    the one whose UUID is INTERDOMAIN_TOPOLOGY_UUID.

    Args:
        context_client: client used to list the topologies of the admin context.

    Returns:
        The set of device UUIDs referenced by the remaining topologies.
    """
    reply = context_client.ListTopologies(ADMIN_CONTEXT_ID)

    # Index by UUID first, then drop the inter-domain topology (if present).
    topologies_by_uuid = {
        topology.topology_id.topology_uuid.uuid : topology
        for topology in reply.topologies
    }
    topologies_by_uuid.pop(INTERDOMAIN_TOPOLOGY_UUID, None)

    # Flatten the device identifiers of all remaining topologies into one set.
    return {
        device_id.device_uuid.uuid
        for topology in topologies_by_uuid.values()
        for device_id in topology.device_ids
    }

def get_local_domain_devices(context_client : ContextClient) -> List[Device]:
    """Return the network devices that belong to the local topologies.

    A device is kept when `device_type_is_network` accepts its type and its
    UUID is among those returned by `get_local_device_uuids`.

    Args:
        context_client: client used to query topologies and devices.

    Returns:
        The matching devices, in the order reported by `ListDevices`.
    """
    known_local_uuids = get_local_device_uuids(context_client)
    device_list = context_client.ListDevices(Empty())
    return [
        candidate
        for candidate in device_list.devices
        if device_type_is_network(candidate.device_type)
        and candidate.device_id.device_uuid.uuid in known_local_uuids
    ]

def is_multi_domain(context_client : ContextClient, endpoint_ids : List[EndPointId]) -> bool:
    """Tell whether any of the given endpoints sits on a non-local device.

    An endpoint is considered remote when the UUID of its device is not among
    the UUIDs returned by `get_local_device_uuids`. Both the remote endpoints
    and the final verdict are logged at INFO level.

    Args:
        context_client: client used to resolve the local device UUIDs.
        endpoint_ids: endpoint identifiers to inspect.

    Returns:
        True if at least one endpoint is remote, False otherwise.
    """
    known_local_uuids = get_local_device_uuids(context_client)

    remote_endpoint_ids = list()
    for candidate in endpoint_ids:
        if candidate.device_id.device_uuid.uuid in known_local_uuids: continue
        remote_endpoint_ids.append(candidate)
    LOGGER.info('remote_endpoint_ids = {:s}'.format(str(remote_endpoint_ids)))

    # A non-empty list of remote endpoints means more than one domain is involved.
    is_multi_domain_ = bool(remote_endpoint_ids)
    LOGGER.info('is_multi_domain = {:s}'.format(str(is_multi_domain_)))
    return is_multi_domain_

def compute_interdomain_path(
    pathcomp_client : PathCompClient, slice_ : Slice
) -> List[Tuple[str, List[EndPointId]]]:
    """Ask the path computation component for a path serving the given slice.

    A single L2NM service mirroring the slice (same context/slice UUIDs and
    endpoints) is submitted with the shortest-path algorithm selected and two
    fixed custom constraints (bandwidth and latency). The hops of the returned
    connection are then grouped by the UUID of the device they belong to.

    Args:
        pathcomp_client: client used to issue the computation request.
        slice_: slice providing the identifiers and endpoints of the request.

    Returns:
        A list of `(device_uuid, endpoint_ids)` tuples, ordered by the first
        appearance of each device in the path; `endpoint_ids` holds every hop
        of the path located on that device.

    Raises:
        Exception: if the reply contains no service, or no connection, whose
            service identifier equals the requested one.
    """
    request = PathCompRequest()
    request.shortest_path.Clear()                                               # pylint: disable=no-member

    requested_service = request.services.add()                                  # pylint: disable=no-member
    requested_service.service_id.context_id.context_uuid.uuid = slice_.slice_id.context_id.context_uuid.uuid
    requested_service.service_id.service_uuid.uuid = slice_.slice_id.slice_uuid.uuid
    requested_service.service_type = ServiceTypeEnum.SERVICETYPE_L2NM
    requested_service_id = requested_service.service_id

    for slice_endpoint_id in slice_.slice_endpoint_ids:
        requested_service.service_endpoint_ids.add().CopyFrom(slice_endpoint_id)

    # Fixed custom constraints attached to the request: (type, value).
    fixed_constraints = (
        ('bandwidth[gbps]', '10.0'),
        ('latency[ms]', '100.0'),
    )
    for constraint_type, constraint_value in fixed_constraints:
        constraint = requested_service.service_constraints.add()
        constraint.custom.constraint_type = constraint_type
        constraint.custom.constraint_value = constraint_value

    LOGGER.info('pathcomp_req = {:s}'.format(grpc_message_to_json_string(request)))
    reply = pathcomp_client.Compute(request)
    LOGGER.info('pathcomp_rep = {:s}'.format(grpc_message_to_json_string(reply)))

    # The reply must echo the requested service...
    matching_service = next(
        (candidate for candidate in reply.services if candidate.service_id == requested_service_id),
        None
    )
    if matching_service is None:
        raise Exception('Service({:s}) not found'.format(grpc_message_to_json_string(requested_service_id)))

    # ... and carry a connection bound to it.
    matching_connection = next(
        (candidate for candidate in reply.connections if candidate.service_id == requested_service_id),
        None
    )
    if matching_connection is None:
        raise Exception(
            'Connection for Service({:s}) not found'.format(grpc_message_to_json_string(requested_service_id)))

    # Group hops per device; dict insertion order keeps devices sorted by first appearance.
    hops_per_device : Dict[str, List[EndPointId]] = dict()
    for hop_endpoint_id in matching_connection.path_hops_endpoint_ids:
        hop_device_uuid = hop_endpoint_id.device_id.device_uuid.uuid
        hops_per_device.setdefault(hop_device_uuid, list()).append(hop_endpoint_id)

    return list(hops_per_device.items())

def compute_traversed_domains(
    context_client : ContextClient, interdomain_path : List[Tuple[str, List[EndPointId]]]
) -> List[Tuple[str, Device, bool, List[EndPointId]]]:
    """Annotate each hop group of an inter-domain path with its device and locality flag.

    Each device UUID of the path is resolved against the devices of the
    inter-domain topology of the admin context. Entries whose device type is
    in DATACENTER_DEVICE_TYPES are skipped.

    Args:
        context_client: client used to query topologies and devices.
        interdomain_path: `(device_uuid, endpoint_ids)` tuples describing the path.

    Returns:
        A list of `(device_uuid, device, is_local_domain, endpoint_ids)` tuples,
        in path order.

    Raises:
        KeyError: if a device UUID of the path is not part of the inter-domain topology.
    """
    known_local_uuids = get_local_device_uuids(context_client)
    abstract_devices_by_uuid = {
        abstract_device.device_id.device_uuid.uuid : abstract_device
        for abstract_device in get_devices_in_topology(
            context_client, ADMIN_CONTEXT_ID, INTERDOMAIN_TOPOLOGY_UUID)
    }

    # Resolve lazily every path entry to its device; unknown UUIDs raise KeyError.
    resolved_hops = (
        (device_uuid, abstract_devices_by_uuid[device_uuid], endpoint_ids)
        for device_uuid, endpoint_ids in interdomain_path
    )

    # NOTE(review): the locality flag is True when the device UUID is NOT among
    # the UUIDs of the local topologies -- confirm this matches the intended meaning.
    return [
        (device_uuid, abstract_device, device_uuid not in known_local_uuids, endpoint_ids)
        for device_uuid, abstract_device, endpoint_ids in resolved_hops
        if abstract_device.device_type not in DATACENTER_DEVICE_TYPES
    ]
