# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method
from common.proto.e2eorchestrator_pb2 import E2EOrchestratorRequest, E2EOrchestratorReply
from common.proto.context_pb2 import Empty, Connection, EndPointId, Link, LinkId, TopologyDetails, TopologyId, Device, Topology, Context, Service
from common.proto.e2eorchestrator_pb2_grpc import E2EOrchestratorServiceServicer
from context.client.ContextClient import ContextClient
from service.client.ServiceClient import ServiceClient
from context.service.database.uuids.EndPoint import endpoint_get_uuid
from common.proto.vnt_manager_pb2 import VNTSubscriptionRequest
from common.tools.grpc.Tools import grpc_message_to_json_string
import grpc
import json
import logging
import networkx as nx
from threading import Thread
import time
from websockets.sync.client import connect
from websockets.sync.server import serve

LOGGER = logging.getLogger(__name__)
# Keep records from the "websockets" library logger from propagating to
# ancestor loggers (e.g. the root logger's handlers).
logging.getLogger("websockets").propagate = False

# Metrics pool used by the RPC methods decorated with
# safe_and_metered_rpc_method below.
METRICS_POOL = MetricsPool("E2EOrchestrator", "RPC")


# Module-level gRPC clients, created when this module is imported.
# context_client is used by _event_received() and Compute(); service_client is
# currently only referenced from commented-out code in _event_received().
context_client: ContextClient = ContextClient()
service_client: ServiceClient = ServiceClient()

# Websocket endpoint of the external NBI that requestSubscription() connects
# to (as ws://EXT_HOST:EXT_PORT).
EXT_HOST = "nbiservice.tfs-ip.svc.cluster.local"
EXT_PORT = "8762"

# Address advertised to the NBI in the VNTSubscriptionRequest. events_server()
# listens on OWN_PORT (converted with int()) on all interfaces.
OWN_HOST = "e2e-orchestratorservice.tfs-e2e.svc.cluster.local"
OWN_PORT = "8761"


def _event_received(websocket):
    """Process the events pushed to the events server over ``websocket``.

    Every incoming message is decoded as JSON and dispatched on its content:

    * if it contains a ``link_id`` key it is parsed as a ``Link``; a
      ``Service`` is prepared for that link (its creation is currently
      disabled) and the raw message is echoed back to the sender;
    * otherwise it is parsed as ``TopologyDetails`` and its context, topology,
      devices and links are stored through ``context_client``.

    A message that fails to be processed is logged with its traceback and
    skipped, so that a single bad event neither closes the connection nor
    disappears silently (the "websockets" logger does not propagate to this
    service's handlers).

    Args:
        websocket: server-side connection handed in by ``serve``; iterating it
            yields the received messages.
    """
    for message in websocket:
        try:
            message_json = json.loads(message)

            if 'link_id' in message_json:
                link = Link(**message_json)

                # Message-typed protobuf fields cannot be assigned directly
                # (protobuf raises AttributeError); populate the nested scalar
                # fields instead.
                service = Service()
                service.service_id.service_uuid.uuid = link.link_id.link_uuid.uuid
                # NOTE(review): value 2 was labelled "Optical" by the original
                # author; confirm against ServiceTypeEnum in context.proto.
                service.service_type = 2
                service.service_status.service_status = 1

                # service_client.CreateService(service)

                websocket.send(message)

            else:
                topology_details = TopologyDetails(**message_json)

                context_id = topology_details.topology_id.context_id
                context = Context()
                context.context_id.CopyFrom(context_id)
                context_client.SetContext(context)

                topology_id = topology_details.topology_id
                topology = Topology()
                topology.topology_id.CopyFrom(topology_id)
                context_client.SetTopology(topology)

                for device in topology_details.devices:
                    LOGGER.info('Setting Device: {}'.format(device))
                    context_client.SetDevice(device)

                for link in topology_details.links:
                    LOGGER.info('Setting Link: {}'.format(link))
                    context_client.SetLink(link)
        except Exception:  # pylint: disable=broad-except
            # Best-effort event processing: report and continue with the next
            # message instead of tearing down the connection.
            LOGGER.exception('Error processing event: {}'.format(message))



def _check_policies(link):
    return True




def requestSubscription():
    """Subscribe to the external NBI's events, then serve incoming events.

    Opens a websocket to ``ws://EXT_HOST:EXT_PORT``, sends a
    ``VNTSubscriptionRequest`` advertising ``OWN_HOST``/``OWN_PORT`` (as JSON)
    and waits for a single reply. The connection is then closed and
    ``events_server()`` is run; that call blocks while the server is running.

    If the connection cannot be opened, the error is logged and the events
    server is not started.
    """
    url = "ws://" + EXT_HOST + ":" + EXT_PORT
    request = VNTSubscriptionRequest()
    request.host = OWN_HOST
    request.port = OWN_PORT
    LOGGER.debug("Trying to connect to {}".format(url))
    try:
        websocket = connect(url)
    except Exception:  # pylint: disable=broad-except
        # Keep the traceback: the cause (DNS, refused, timeout...) matters.
        LOGGER.exception('Error connecting to {}'.format(url))
        return

    with websocket:
        LOGGER.debug("Connected to {}".format(url))
        send = grpc_message_to_json_string(request)
        websocket.send(send)
        LOGGER.debug("Sent: {}".format(send))
        try:
            message = websocket.recv()
            LOGGER.debug("Received message from WebSocket: {}".format(message))
        except Exception as ex:  # pylint: disable=broad-except
            LOGGER.info('Exception receiving from WebSocket: {}'.format(ex))

    # Log before starting the server: events_server() blocks in
    # serve_forever(), so anything after it would not run while serving.
    LOGGER.info('Subscription requested')
    events_server()


def events_server():
    """Run the websocket server that receives pushed events (blocking).

    Listens on all interfaces at ``OWN_PORT`` and hands every connection to
    ``_event_received``. If the server cannot be created, the error is logged
    and the function returns instead of failing on an unbound ``server``.
    """
    all_hosts = "0.0.0.0"

    try:
        server = serve(_event_received, all_hosts, int(OWN_PORT))
    except Exception as ex:  # pylint: disable=broad-except
        LOGGER.error('Error starting server on {}:{}'.format(all_hosts, OWN_PORT))
        LOGGER.error('Exception!: {}'.format(ex))
        # Nothing to serve: without this return, `with server:` below would
        # raise UnboundLocalError.
        return

    with server:
        LOGGER.info("Running events server...: {}:{}".format(all_hosts, OWN_PORT))
        server.serve_forever()
        LOGGER.info("Exiting events server...")



class E2EOrchestratorServiceServicerImpl(E2EOrchestratorServiceServicer):
    """gRPC servicer implementing the E2E Orchestrator service.

    Creating an instance starts a background thread running
    ``requestSubscription()``. ``Compute()`` answers path-computation requests
    using the devices and links currently stored in Context.
    """

    def __init__(self):
        LOGGER.debug("Creating Servicer...")
        LOGGER.debug("Servicer Created")

        try:
            LOGGER.info("Requesting subscription")
            # requestSubscription() ends up blocking in events_server(), so it
            # runs on its own thread rather than in the constructor.
            subscription_thread = Thread(target=requestSubscription)
            subscription_thread.start()
        except Exception as ex:  # pylint: disable=broad-except
            # A failure to start the subscription is an error; keep the
            # traceback so it can be diagnosed.
            LOGGER.exception("Exception!: {}".format(ex))

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def Compute(self, request: E2EOrchestratorRequest, context: grpc.ServicerContext) -> E2EOrchestratorReply:
        """Compute a path between the first two endpoints of ``request.service``.

        An undirected graph is built whose nodes are endpoint UUIDs. Endpoints
        of the same device are fully meshed, and each link joins its first two
        endpoints. The shortest path between the service's first two endpoints
        is split into consecutive pairs ``(p[0], p[1]), (p[2], p[3]), ...``
        (a trailing unpaired endpoint is ignored). Each pair becomes one
        ``Connection``.

        Args:
            request: carries the service whose endpoints are to be connected.
            context: gRPC call context (unused).

        Returns:
            An ``E2EOrchestratorReply`` holding a copy of the requested service
            and one ``Connection`` per endpoint pair along the path.

        Raises:
            ValueError: if the service has fewer than two endpoints.
            networkx.NodeNotFound / networkx.NetworkXNoPath: if an endpoint is
                not in the graph or no path exists (raised by
                ``nx.shortest_path``).
        """
        endpoints_ids = [endpoint_get_uuid(endpoint_id)[2]
                         for endpoint_id in request.service.service_endpoint_ids]
        if len(endpoints_ids) < 2:
            raise ValueError(
                'Service must have at least 2 endpoints to compute a path; got {}'.format(
                    len(endpoints_ids)))

        graph = nx.Graph()

        # Model each device as a full mesh between its endpoints. The graph is
        # undirected, so each unordered pair is added once.
        for device in context_client.ListDevices(Empty()).devices:
            endpoints_uuids = [endpoint.endpoint_id.endpoint_uuid.uuid
                               for endpoint in device.device_endpoints]
            graph.add_nodes_from(endpoints_uuids)
            graph.add_edges_from(
                (ep_a, ep_z)
                for idx, ep_a in enumerate(endpoints_uuids)
                for ep_z in endpoints_uuids[idx + 1:]
                if ep_a != ep_z
            )

        for link in context_client.ListLinks(Empty()).links:
            link_eps = [endpoint_id.endpoint_uuid.uuid
                        for endpoint_id in link.link_endpoint_ids]
            if len(link_eps) < 2:
                # A malformed link must not abort the whole computation.
                LOGGER.warning('Skipping link with less than 2 endpoints: {}'.format(
                    grpc_message_to_json_string(link.link_id)))
                continue
            # Only the first two endpoints of a link are considered.
            graph.add_edge(link_eps[0], link_eps[1])

        shortest = nx.shortest_path(graph, endpoints_ids[0], endpoints_ids[1])

        path = E2EOrchestratorReply()
        path.services.append(copy.deepcopy(request.service))
        for ep_a_uuid, ep_z_uuid in zip(shortest[0::2], shortest[1::2]):
            ep_a_uuid = str(ep_a_uuid)
            ep_z_uuid = str(ep_z_uuid)

            conn = Connection()
            conn.connection_id.connection_uuid.uuid = '{}_->_{}'.format(ep_a_uuid, ep_z_uuid)

            ep_a_id = EndPointId()
            ep_a_id.endpoint_uuid.uuid = ep_a_uuid
            conn.path_hops_endpoint_ids.append(ep_a_id)

            ep_z_id = EndPointId()
            ep_z_id.endpoint_uuid.uuid = ep_z_uuid
            conn.path_hops_endpoint_ids.append(ep_z_id)

            path.connections.append(conn)

        # The reply must be returned: a gRPC handler returning None cannot be
        # serialized.
        return path


