# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy, logging, os
from common.proto.context_pb2 import Context, ContextId, DeviceId, Link, LinkId, Topology, Device, TopologyId
from common.proto.pathcomp_pb2 import PathCompRequest
from common.tools.grpc.Tools import grpc_message_to_json
from common.tools.object_factory.Constraint import json_constraint
from context.client.ContextClient import ContextClient
from device.client.DeviceClient import DeviceClient
from pathcomp.frontend.client.PathCompClient import PathCompClient

# Scenarios:
#from .Objects_A_B_C import CONTEXTS, DEVICES, LINKS, OBJECTS_PER_TOPOLOGY, SERVICES, TOPOLOGIES
from .Objects_DC_CSGW_TN import CONTEXTS, DEVICES, LINKS, OBJECTS_PER_TOPOLOGY, SERVICES, TOPOLOGIES

# Configure the backend environment variables here, before the fixtures overwrite them,
# so that the real PathComp backend is used.
DEFAULT_PATHCOMP_BACKEND_SCHEME  = 'http'
DEFAULT_PATHCOMP_BACKEND_HOST    = '127.0.0.1'
DEFAULT_PATHCOMP_BACKEND_PORT    = '8081'
DEFAULT_PATHCOMP_BACKEND_BASEURL = '/pathComp/api/v1/compRoute'

# Scheme and base URL: keep whatever is already exported, otherwise fall back to the defaults.
os.environ.setdefault('PATHCOMP_BACKEND_SCHEME', DEFAULT_PATHCOMP_BACKEND_SCHEME)
os.environ.setdefault('PATHCOMP_BACKEND_BASEURL', DEFAULT_PATHCOMP_BACKEND_BASEURL)

# The IP:port of the backend container is resolved with the following precedence:
# 1. env vars PATHCOMP_BACKEND_HOST & PATHCOMP_BACKEND_PORT, when already set;
# 2. env vars PATHCOMPSERVICE_SERVICE_HOST & PATHCOMPSERVICE_SERVICE_PORT_HTTP;
# 3. DEFAULT_PATHCOMP_BACKEND_HOST & DEFAULT_PATHCOMP_BACKEND_PORT.
backend_host = os.environ.get('PATHCOMPSERVICE_SERVICE_HOST', DEFAULT_PATHCOMP_BACKEND_HOST)
os.environ.setdefault('PATHCOMP_BACKEND_HOST', backend_host)

backend_port = os.environ.get('PATHCOMPSERVICE_SERVICE_PORT_HTTP', DEFAULT_PATHCOMP_BACKEND_PORT)
os.environ.setdefault('PATHCOMP_BACKEND_PORT', backend_port)

from .PrepareTestScenario import ( # pylint: disable=unused-import
    # be careful, order of symbols is important here!
    mock_service, pathcomp_service, context_client, device_client, pathcomp_client)

# Module-level logger named after this module; its threshold is set to DEBUG.
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)

def test_prepare_environment(
    context_client : ContextClient, # pylint: disable=redefined-outer-name
    device_client : DeviceClient):  # pylint: disable=redefined-outer-name
    """Load the selected scenario through the Context and Device clients.

    Contexts, topologies, devices and links are created in that order; then,
    for every entry of OBJECTS_PER_TOPOLOGY, the device and link identifiers
    not yet listed in the stored topology are appended and the topology is
    written back.
    """

    for context in CONTEXTS:
        context_client.SetContext(Context(**context))
    for topology in TOPOLOGIES:
        context_client.SetTopology(Topology(**topology))
    for device in DEVICES:
        device_client.AddDevice(Device(**device))
    for link in LINKS:
        context_client.SetLink(Link(**link))

    for topology_id, device_ids, link_ids in OBJECTS_PER_TOPOLOGY:
        # Work on a copy of the topology as currently stored in Context.
        topology = Topology()
        topology.CopyFrom(context_client.GetTopology(TopologyId(**topology_id)))

        # Append only the devices whose UUID is not already part of the topology.
        known_device_uuids = {dev_id.device_uuid.uuid for dev_id in topology.device_ids}
        new_device_ids = [
            DeviceId(**dev_id)
            for dev_id in device_ids
            if dev_id['device_uuid']['uuid'] not in known_device_uuids
        ]
        topology.device_ids.extend(new_device_ids)

        # Same treatment for the links.
        known_link_uuids = {lnk_id.link_uuid.uuid for lnk_id in topology.link_ids}
        new_link_ids = [
            LinkId(**lnk_id)
            for lnk_id in link_ids
            if lnk_id['link_uuid']['uuid'] not in known_link_uuids
        ]
        topology.link_ids.extend(new_link_ids)

        context_client.SetTopology(topology)

def test_request_service_shortestpath(
    pathcomp_client : PathCompClient):  # pylint: disable=redefined-outer-name
    """Request a shortest-path computation for the scenario services.

    Checks that every requested service appears in the reply, both in the
    list of services and as the owner of at least one connection.
    """

    def _service_key(entry):
        # '<context_uuid>/<service_uuid>' of the service referenced by a service or connection
        service_id = entry['service_id']
        return '{:s}/{:s}'.format(
            service_id['context_id']['context_uuid']['uuid'],
            service_id['service_uuid']['uuid'])

    request_services = copy.deepcopy(SERVICES)
    # Constraints can be injected here if needed, for instance:
    #   request_services[0]['service_constraints'] = [
    #       json_constraint('bandwidth[gbps]', 1000.0),
    #       json_constraint('latency[ms]',     1200.0),
    #   ]
    pathcomp_request = PathCompRequest(services=request_services)
    # The shortest path algorithm has no attributes; clearing its field is a hack to select it.
    pathcomp_request.shortest_path.Clear()

    json_reply = grpc_message_to_json(pathcomp_client.Compute(pathcomp_request))
    reply_services = json_reply['services']
    reply_connections = json_reply['connections']
    assert len(request_services) <= len(reply_services)

    requested_keys = {_service_key(service) for service in request_services}
    replied_keys = {_service_key(service) for service in reply_services}

    # Every requested service must be present in the reply.
    # Additional non-requested services (i.e., sub-services) are tolerated.
    services_without_reply = requested_keys.difference(replied_keys)
    assert len(services_without_reply) == 0

    connection_keys = {_service_key(connection) for connection in reply_connections}

    # Every requested service must have a connection associated.
    # Additional connections (i.e., connections for sub-services) are tolerated.
    services_without_connection = requested_keys.difference(connection_keys)
    assert len(services_without_connection) == 0

    # TODO: implement other checks. examples:
    # - request service and reply service endpoints match
    # - request service and reply connection endpoints match
    # - reply sub-service and reply sub-connection endpoints match
    # - others?


def test_request_service_kshortestpath(
    pathcomp_client : PathCompClient):  # pylint: disable=redefined-outer-name
    """Request a k-shortest-path computation (k_inspection=2, k_return=2).

    Checks that every requested service appears in the reply, both in the
    list of services and as the owner of at least one connection.
    """

    def _service_key(entry):
        # '<context_uuid>/<service_uuid>' of the service referenced by a service or connection
        service_id = entry['service_id']
        return '{:s}/{:s}'.format(
            service_id['context_id']['context_uuid']['uuid'],
            service_id['service_uuid']['uuid'])

    # Use a private copy of the module-level SERVICES fixture, as the shortest-path test does,
    # so that nothing done to the request here can leak into the other tests.
    request_services = copy.deepcopy(SERVICES)
    pathcomp_request = PathCompRequest(services=request_services)
    pathcomp_request.k_shortest_path.k_inspection = 2   #pylint: disable=no-member
    pathcomp_request.k_shortest_path.k_return = 2       #pylint: disable=no-member

    json_reply = grpc_message_to_json(pathcomp_client.Compute(pathcomp_request))
    reply_services = json_reply['services']
    reply_connections = json_reply['connections']
    assert len(request_services) <= len(reply_services)

    requested_keys = {_service_key(service) for service in request_services}
    replied_keys = {_service_key(service) for service in reply_services}

    # Every requested service must be present in the reply.
    # Additional non-requested services (i.e., sub-services) are tolerated.
    services_without_reply = requested_keys - replied_keys
    assert not services_without_reply

    connection_keys = {_service_key(connection) for connection in reply_connections}

    # Every requested service must have a connection associated.
    # Additional connections (i.e., connections for sub-services) are tolerated.
    services_without_connection = requested_keys - connection_keys
    assert not services_without_connection

    # TODO: implement other checks. examples:
    # - request service and reply service endpoints match
    # - request service and reply connection endpoints match
    # - reply sub-service and reply sub-connection endpoints match
    # - others?


def test_request_service_kdisjointpath(
    pathcomp_client : PathCompClient):  # pylint: disable=redefined-outer-name
    """Request a k-disjoint-path computation (num_disjoint=2).

    Checks that every requested service appears in the reply, both in the
    list of services and as the owner of at least one connection.
    """

    def _service_key(entry):
        # '<context_uuid>/<service_uuid>' of the service referenced by a service or connection
        service_id = entry['service_id']
        return '{:s}/{:s}'.format(
            service_id['context_id']['context_uuid']['uuid'],
            service_id['service_uuid']['uuid'])

    # Use a private copy of the module-level SERVICES fixture, as the shortest-path test does,
    # so that nothing done to the request here can leak into the other tests.
    request_services = copy.deepcopy(SERVICES)
    pathcomp_request = PathCompRequest(services=request_services)
    pathcomp_request.k_disjoint_path.num_disjoint = 2   #pylint: disable=no-member

    json_reply = grpc_message_to_json(pathcomp_client.Compute(pathcomp_request))
    reply_services = json_reply['services']
    reply_connections = json_reply['connections']
    assert len(request_services) <= len(reply_services)

    requested_keys = {_service_key(service) for service in request_services}
    replied_keys = {_service_key(service) for service in reply_services}

    # Every requested service must be present in the reply.
    # Additional non-requested services (i.e., sub-services) are tolerated.
    services_without_reply = requested_keys - replied_keys
    assert not services_without_reply

    connection_keys = {_service_key(connection) for connection in reply_connections}

    # Every requested service must have a connection associated.
    # Additional connections (i.e., connections for sub-services) are tolerated.
    services_without_connection = requested_keys - connection_keys
    assert not services_without_connection

    # TODO: implement other checks. examples:
    # - request service and reply service endpoints match
    # - request service and reply connection endpoints match
    # - reply sub-service and reply sub-connection endpoints match
    # - others?


def test_cleanup_environment(
    context_client : ContextClient, # pylint: disable=redefined-outer-name
    device_client : DeviceClient):  # pylint: disable=redefined-outer-name
    """Remove the scenario objects: links first, then devices, topologies and finally contexts."""

    for link in LINKS:
        link_id = LinkId(**link['link_id'])
        context_client.RemoveLink(link_id)

    for device in DEVICES:
        device_id = DeviceId(**device['device_id'])
        device_client.DeleteDevice(device_id)

    for topology in TOPOLOGIES:
        topology_id = TopologyId(**topology['topology_id'])
        context_client.RemoveTopology(topology_id)

    for context in CONTEXTS:
        context_id = ContextId(**context['context_id'])
        context_client.RemoveContext(context_id)