# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from common.proto.context_pb2 import Context, ContextId, Device, DeviceId, Empty, Link, LinkId, Topology, TopologyId
from context.client.ContextClient import ContextClient
from device.client.DeviceClient import DeviceClient
from .Fixtures import context_client, device_client
#from .Objects_BigNet import CONTEXT_ID, CONTEXTS, DEVICES, LINKS, TOPOLOGIES
#from .Objects_DC_CSGW_TN import CONTEXT_ID, CONTEXTS, DEVICES, LINKS, TOPOLOGIES, OBJECTS_PER_TOPOLOGY
#from .Objects_DC_CSGW_TN_OLS import CONTEXT_ID, CONTEXTS, DEVICES, LINKS, TOPOLOGIES, OBJECTS_PER_TOPOLOGY
from .Objects_DC_CSGW_OLS import CONTEXT_ID, CONTEXTS, DEVICES, LINKS, TOPOLOGIES, OBJECTS_PER_TOPOLOGY


# Module-level logger for this test module; its own level is forced to DEBUG.
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(level=logging.DEBUG)


def test_scenario_empty(context_client : ContextClient):  # pylint: disable=redefined-outer-name
    """Check that the Context component holds no contexts, devices or links before the scenario is loaded."""
    # ----- List entities - Ensure database is empty -------------------------------------------------------------------
    # (ContextClient RPC name, repeated field of its reply that must be empty), checked in this order.
    empty_listings = (
        ('ListContexts', 'contexts'),
        ('ListDevices',  'devices' ),
        ('ListLinks',    'links'   ),
    )
    for rpc_name, reply_field in empty_listings:
        reply = getattr(context_client, rpc_name)(Empty())
        assert len(getattr(reply, reply_field)) == 0


def _add_missing_ids_to_topology(
    context_client : ContextClient,  # pylint: disable=redefined-outer-name
    json_topology_id : dict, json_object_ids : list, ids_field : str, uuid_field : str, grpc_id_class : type
) -> None:
    """Make a stored topology reference every given object id it does not reference yet.

    Args:
        context_client: client used to read the topology and to write it back.
        json_topology_id: JSON (dict) form of the TopologyId to update.
        json_object_ids: JSON (dict) forms of the object ids to reference (e.g. DeviceIds or LinkIds).
        ids_field: name of the repeated field in Topology holding those ids (e.g. 'device_ids').
        uuid_field: name of the uuid sub-message inside each id (e.g. 'device_uuid').
        grpc_id_class: message class used to build each id from its JSON form (e.g. DeviceId).
    """
    topology = Topology()
    topology.CopyFrom(context_client.GetTopology(TopologyId(**json_topology_id)))
    stored_ids = getattr(topology, ids_field)

    # Track every uuid already referenced, so neither ids stored in the topology nor ids repeated
    # inside json_object_ids end up listed twice.
    known_uuids = {getattr(stored_id, uuid_field).uuid for stored_id in stored_ids}
    ids_to_add = []
    for json_object_id in json_object_ids:
        object_uuid = json_object_id[uuid_field]['uuid']
        if object_uuid in known_uuids:
            continue
        known_uuids.add(object_uuid)
        ids_to_add.append(grpc_id_class(**json_object_id))

    stored_ids.extend(ids_to_add)
    context_client.SetTopology(topology)


def test_prepare_environment(
    context_client : ContextClient, # pylint: disable=redefined-outer-name
    device_client : DeviceClient):  # pylint: disable=redefined-outer-name
    """Load the scenario: contexts, topologies, devices, links and the per-topology id references.

    Devices are added through the Device component. Contexts, topologies and links are written
    directly to the Context component. After devices (and again after links) are created, each
    topology listed in OBJECTS_PER_TOPOLOGY is extended with the device (resp. link) ids listed for it.
    """
    for context  in CONTEXTS  : context_client.SetContext (Context (**context ))
    for topology in TOPOLOGIES: context_client.SetTopology(Topology(**topology))

    for device   in DEVICES   : device_client .AddDevice  (Device  (**device  ))
    for topology_id, device_ids, _ in OBJECTS_PER_TOPOLOGY:
        _add_missing_ids_to_topology(context_client, topology_id, device_ids, 'device_ids', 'device_uuid', DeviceId)

    # Links are created only after all devices exist; presumably they reference device endpoints — confirm.
    for link     in LINKS     : context_client.SetLink    (Link    (**link    ))
    for topology_id, _, link_ids in OBJECTS_PER_TOPOLOGY:
        _add_missing_ids_to_topology(context_client, topology_id, link_ids, 'link_ids', 'link_uuid', LinkId)


def test_scenario_ready(context_client : ContextClient):  # pylint: disable=redefined-outer-name
    """Check that the Context component holds exactly the entities loaded by the scenario, and no services."""
    # ----- List entities - Ensure scenario is ready -------------------------------------------------------------------
    # (ContextClient RPC name, request builder, repeated field of the reply, expected number of entries),
    # checked in this order; requests are built lazily, right before each call.
    expected_listings = (
        ('ListContexts',   Empty,                           'contexts',   len(CONTEXTS)  ),
        ('ListTopologies', lambda: ContextId(**CONTEXT_ID), 'topologies', len(TOPOLOGIES)),
        ('ListDevices',    Empty,                           'devices',    len(DEVICES)   ),
        ('ListLinks',      Empty,                           'links',      len(LINKS)     ),
        ('ListServices',   lambda: ContextId(**CONTEXT_ID), 'services',   0              ),
    )
    for rpc_name, build_request, reply_field, expected_count in expected_listings:
        reply = getattr(context_client, rpc_name)(build_request())
        assert len(getattr(reply, reply_field)) == expected_count
