# Copyright 2022-2024 ETSI SDG TeraFlowSDN (TFS) (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grpc, json, logging, uuid
from confluent_kafka import Consumer as KafkaConsumer
from confluent_kafka import Producer as KafkaProducer
from confluent_kafka import KafkaError
from common.method_wrappers.Decorator import MetricsPool, safe_and_metered_rpc_method
from common.proto.context_pb2 import Empty, Link, LinkId, LinkList
from common.proto.vnt_manager_pb2_grpc import VNTManagerServiceServicer
from common.tools.grpc.Tools import grpc_message_to_json_string
from common.tools.kafka.Variables import KafkaConfig, KafkaTopic
from context.client.ContextClient import ContextClient
from .vntm_config_device import configure, deconfigure


# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Metrics pool handed to the `safe_and_metered_rpc_method` decorator on every RPC below.
METRICS_POOL = MetricsPool("VNTManager", "RPC")


class VNTManagerServiceServicerImpl(VNTManagerServiceServicer):
    """Servicer implementing the VNTManager RPCs.

    Read operations are answered from the Context service. Create/remove
    operations are forwarded as JSON events over Kafka, and the Context
    service is updated once the matching reply is received.
    """

    def __init__(self):
        LOGGER.debug("Creating Servicer...")
        self.context_client = ContextClient()
        self.links = []
        LOGGER.debug("Servicer Created")

    def _request_and_wait_reply(self, event : str, request, log_tag : str) -> dict:
        """Publish a request event on Kafka and block until its reply arrives.

        The request is published on the VNTMANAGER_REQUEST topic under a fresh
        random key; replies are read from the VNTMANAGER_RESPONSE topic and
        matched against that key.

        NOTE(review): there is no overall timeout; this loops until a reply
        with the matching key is received or the consumer reports an error.
        NOTE(review): the consumer subscribes with 'auto.offset.reset'='latest'
        after the request was produced, so a very fast reply could presumably
        be missed -- confirm against the responder's behaviour.

        Args:
            event: value stored under the 'event' field of the request payload.
            request: gRPC message serialized into the 'data' field.
            log_tag: method name used as prefix in log records.

        Returns:
            The JSON-decoded value of the reply message.

        Raises:
            RuntimeError: if the consumer reports an error other than
                end-of-partition.
        """
        request_key = str(uuid.uuid4())
        kafka_producer = KafkaProducer({
            'bootstrap.servers' : KafkaConfig.get_kafka_address()
        })

        vntm_request = json.dumps({
            'event': event, 'data': grpc_message_to_json_string(request)
        }).encode('utf-8')
        LOGGER.info('[{:s}] vntm_request={:s}'.format(log_tag, str(vntm_request)))
        kafka_producer.produce(
            KafkaTopic.VNTMANAGER_REQUEST.value, key=request_key, value=vntm_request
        )
        kafka_producer.flush()

        # A random group id is used so this consumer does not share offsets
        # with any other consumer of the response topic.
        kafka_consumer = KafkaConsumer({
            'bootstrap.servers' : KafkaConfig.get_kafka_address(),
            'group.id'          : str(uuid.uuid4()),
            'auto.offset.reset' : 'latest'
        })
        try:
            kafka_consumer.subscribe([KafkaTopic.VNTMANAGER_RESPONSE.value])
            while True:
                receive_msg = kafka_consumer.poll(2.0)
                if receive_msg is None: continue
                LOGGER.info('[{:s}] receive_msg={:s}'.format(log_tag, str(receive_msg)))

                error = receive_msg.error()
                if error:
                    if error.code() == KafkaError._PARTITION_EOF: continue
                    # Do not try to interpret an error message as a reply.
                    error_msg = 'Consumer error: {:s}'.format(str(error))
                    LOGGER.error(error_msg)
                    raise RuntimeError(error_msg)

                # Messages without a key cannot be a reply to this request.
                raw_key = receive_msg.key()
                if raw_key is None: continue
                if raw_key.decode('utf-8') != request_key: continue

                return json.loads(receive_msg.value().decode('utf-8'))
        finally:
            # Always release the consumer, whatever the outcome of the wait.
            kafka_consumer.close()

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def ListVirtualLinks(self, request : Empty, context : grpc.ServicerContext) -> LinkList:
        """Return the links known to the Context service that are flagged as virtual."""
        links = self.context_client.ListLinks(Empty()).links
        # Wrap the result in the declared response message instead of a plain list.
        return LinkList(links=[link for link in links if link.virtual])

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def GetVirtualLink(self, request : LinkId, context : grpc.ServicerContext) -> Link:
        """Return the requested link if it is virtual, otherwise an empty Link."""
        link = self.context_client.GetLink(request)
        # The response type of this RPC is Link, so "not found" is an empty Link.
        return link if link.virtual else Link()

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def SetVirtualLink(self, request : Link, context : grpc.ServicerContext) -> LinkId:
        """Request creation of a virtual link and store the replied link in Context.

        Failures are logged and not propagated; the id of the requested link is
        returned in every case.
        """
        try:
            LOGGER.info('[SetVirtualLink] request={:s}'.format(grpc_message_to_json_string(request)))
            reply = self._request_and_wait_reply('vlink_create', request, 'SetVirtualLink')

            link = Link(**reply)
            # at this point, we know the request was accepted and an optical connection was created

            # configure('CSGW1', 'xe5', 'CSGW2', 'xe5', 'ecoc2024-1')
            self.context_client.SetLink(link)
        except Exception: # pylint: disable=broad-except
            MSG = 'Exception setting virtual link={:s}'
            LOGGER.exception(MSG.format(str(request.link_id.link_uuid.uuid)))
        return request.link_id

    @safe_and_metered_rpc_method(METRICS_POOL, LOGGER)
    def RemoveVirtualLink(self, request : LinkId, context : grpc.ServicerContext) -> Empty:
        """Request removal of a virtual link and remove the replied link id from Context.

        Failures are logged and not propagated; Empty is returned in every case.
        """
        try:
            LOGGER.debug('Removing virtual link')
            reply = self._request_and_wait_reply('vlink_remove', request, 'RemoveVirtualLink')

            link_id = LinkId(**reply)
            # at this point, we know the request was accepted and an optical connection was deleted

            # deconfigure('CSGW1', 'xe5', 'CSGW2', 'xe5', 'ecoc2024-1')
            self.context_client.RemoveLink(link_id)
            LOGGER.info('Removed')
        except Exception: # pylint: disable=broad-except
            MSG = 'Exception removing virtual link={:s}'
            LOGGER.exception(MSG.format(str(request.link_uuid.uuid)))

        return Empty()
