# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Ref: https://github.com/openconfig/reference/blob/master/rpc/gnmi/gnmi-specification.md
# Ref: https://github.com/openconfig/gnmi/blob/master/proto/gnmi/gnmi.proto

import anytree, logging, queue, threading
from typing import Any, Iterator, List, Optional, Tuple, Union
from common.type_checkers.Checkers import chk_float, chk_length, chk_string, chk_type
from device.service.driver_api._Driver import (
    RESOURCE_ENDPOINTS, RESOURCE_INTERFACES, RESOURCE_NETWORK_INSTANCES,
    _Driver)

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Translation table from the special resource-key constants imported from the
# driver API to the path-like keys used in their place. GetConfig() looks every
# requested key up here and keeps the key unchanged when it is not listed.
SPECIAL_RESOURCE_MAPPINGS = {
    RESOURCE_ENDPOINTS        : '/endpoints',
    RESOURCE_INTERFACES       : '/interfaces',
    RESOURCE_NETWORK_INSTANCES: '/net-instances',
}

class MonitoringThread(threading.Thread):
    """Daemon thread intended to turn subscription requests into telemetry samples.

    The thread receives two queues: `in_subscriptions`, from which subscription
    requests are meant to be read, and `out_samples`, into which
    `(timestamp, resource_key, value)` tuples are meant to be written. The
    telemetry stream itself is not implemented yet (see the TODOs in `run`), so
    for now the thread produces no samples and simply stays parked until it is
    asked to terminate through `stop()` or `join()`.
    """

    def __init__(self, in_subscriptions : queue.Queue, out_samples : queue.Queue) -> None:
        """Store the queues and create the termination flag.

        Args:
            in_subscriptions: queue the subscription requests are read from.
            out_samples: queue the collected samples are written to.
        """
        super().__init__(daemon=True)
        self._in_subscriptions = in_subscriptions
        self._out_samples = out_samples
        # Set by stop()/join(); makes run() return.
        self._terminate = threading.Event()

    def stop(self) -> None:
        """Ask the thread to finish; returns immediately without waiting for it."""
        self._terminate.set()

    def join(self, timeout : Optional[float] = None) -> None:
        """Request termination and then wait for the thread to finish.

        Termination is requested here because run() only returns once the
        termination flag is set; a plain join() would otherwise block forever.

        Args:
            timeout: maximum number of seconds to wait, or None to wait
                until the thread has finished.

        Raises:
            RuntimeError: if the thread has not been started (raised by
                `threading.Thread.join`).
        """
        self._terminate.set()
        super().join(timeout)

    def run(self) -> None:
        """Thread body: wait until termination is requested.

        The previous placeholder referenced undefined names and crashed with a
        NameError as soon as the thread started; nothing is published until
        real samples are available.
        """
        # TODO: req_iterator = generate_requests(self._in_subscriptions)
        # TODO: stub.Subscribe(req_iterator)
        # TODO: for every update received from the stream:
        #       self._out_samples.put_nowait((timestamp, resource_key, value))
        self._terminate.wait()

class EmulatedDriver(_Driver):
    """Skeleton driver exposing the `_Driver` operations implemented below.

    Configuration retrieval and (un)subscription handling are not implemented
    yet: the corresponding methods only validate their arguments. Samples are
    read from an internal queue shared with a `MonitoringThread`.
    """

    def __init__(self, address : str, port : int, **settings) -> None: # pylint: disable=super-init-not-called
        """Create the synchronization primitives, queues and monitoring thread.

        Args:
            address: device address (currently unused).
            port: device port (currently unused).
            **settings: additional driver settings (currently unused).
        """
        # Held while configuration and subscription requests are processed.
        self.__lock = threading.Lock()

        # endpoints = settings.get('endpoints', [])

        # __started: set once Connect() has started the monitoring thread.
        # __terminate: set by Disconnect() to stop GetState() iterators.
        self.__started = threading.Event()
        self.__terminate = threading.Event()

        # Queues shared with the monitoring thread.
        self.__in_subscriptions = queue.Queue()
        self.__out_samples = queue.Queue()

        self.__monitoring_thread = MonitoringThread(self.__in_subscriptions, self.__out_samples)

    @staticmethod
    def __validate_subscription(name : str, subscription : Any) -> Tuple[str, float, float]:
        """Check that `subscription` is a valid (resource_key, duration, interval) triple.

        Args:
            name: label identifying the item in the messages of the checkers.
            subscription: the item to validate.

        Returns:
            The unpacked `(resource_key, sampling_duration, sampling_interval)`.

        Raises:
            Exception: whatever the `chk_*` helpers raise for invalid input.
        """
        chk_type(name, subscription, (list, tuple))
        chk_length(name, subscription, min_length=3, max_length=3)
        resource_key, sampling_duration, sampling_interval = subscription
        chk_string(name + '.resource_key', resource_key, allow_empty=False)
        chk_float(name + '.sampling_duration', sampling_duration, min_value=0)
        chk_float(name + '.sampling_interval', sampling_interval, min_value=0)
        return resource_key, sampling_duration, sampling_interval

    def Connect(self) -> bool:
        """Start the monitoring thread unless the driver was already started.

        Returns:
            Always True.
        """
        # If started, assume it is already connected
        if self.__started.is_set(): return True

        # TODO: check capabilities
        self.__monitoring_thread.start()

        # Indicate the driver is now connected to the device
        self.__started.set()
        return True

    def Disconnect(self) -> bool:
        """Signal termination and wait for the monitoring thread, if it was started.

        Returns:
            Always True.
        """
        # Trigger termination of loops and processes
        self.__terminate.set()

        # If not started, assume it is already disconnected. The thread must
        # not be joined in that case: joining a thread that was never started
        # raises RuntimeError.
        if not self.__started.is_set(): return True

        # TODO: send unsubscriptions
        # TODO: terminate monitoring thread
        # TODO: disconnect gRPC
        self.__monitoring_thread.join()
        return True

    def GetInitialConfig(self) -> List[Tuple[str, Any]]:
        """Return the initial configuration, which is currently always empty."""
        with self.__lock:
            return []

    def GetConfig(self, resource_keys : List[str] = []) -> List[Tuple[str, Union[Any, None, Exception]]]:
        """Validate the requested resource keys.

        Retrieval is not implemented yet, so valid keys produce no entry in the
        result; only validation failures are reported.

        Args:
            resource_keys: list of resource keys to retrieve.

        Returns:
            A list with one `(resource_key, exception)` tuple per key that
            failed validation.

        Raises:
            Exception: whatever `chk_type` raises when `resource_keys` is not a list.
        """
        chk_type('resources', resource_keys, list)
        with self.__lock:
            results = []
            for i, resource_key in enumerate(resource_keys):
                str_resource_name = 'resource_key[#{:d}]'.format(i)
                try:
                    chk_string(str_resource_name, resource_key, allow_empty=False)
                    # Replace special keys by their mapped path; keep others as they are.
                    resource_key = SPECIAL_RESOURCE_MAPPINGS.get(resource_key, resource_key)
                except Exception as e: # pylint: disable=broad-except
                    LOGGER.exception('Exception validating {:s}: {:s}'.format(str_resource_name, str(resource_key)))
                    results.append((resource_key, e)) # if validation fails, store the exception
                    continue

                # TODO: if resource_key == '/endpoints': return the list of endpoints
                #       results.extend(endpoints)
            return results

    def SubscribeState(self, subscriptions : List[Tuple[str, float, float]]) -> List[Union[bool, Exception]]:
        """Validate subscription requests; issuing them is not implemented yet.

        Args:
            subscriptions: list of `(resource_key, sampling_duration,
                sampling_interval)` items.

        Returns:
            One entry per item, in order: True when the item is valid, or the
            exception raised while validating it.

        Raises:
            Exception: whatever `chk_type` raises when `subscriptions` is not a list.
        """
        chk_type('subscriptions', subscriptions, list)
        if not subscriptions: return []
        results = []
        with self.__lock:
            for i, subscription in enumerate(subscriptions):
                str_subscription_name = 'subscriptions[#{:d}]'.format(i)
                try:
                    self.__validate_subscription(str_subscription_name, subscription)
                except Exception as e: # pylint: disable=broad-except
                    # Log the item itself: its fields may not have been unpacked
                    # when validation failed.
                    LOGGER.exception('Exception validating {:s}: {:s}'.format(str_subscription_name, str(subscription)))
                    results.append(e) # if validation fails, store the exception
                    continue

                # TODO: format subscription
                # TODO: self.__in_subscriptions.put_nowait(subscription)
                results.append(True)
        return results

    def UnsubscribeState(self, subscriptions : List[Tuple[str, float, float]]) -> List[Union[bool, Exception]]:
        """Validate unsubscription requests; issuing them is not implemented yet.

        Args:
            subscriptions: list of `(resource_key, sampling_duration,
                sampling_interval)` items.

        Returns:
            One entry per item, in order: True when the item is valid, or the
            exception raised while validating it.

        Raises:
            Exception: whatever `chk_type` raises when `subscriptions` is not a list.
        """
        chk_type('subscriptions', subscriptions, list)
        if not subscriptions: return []
        results = []
        with self.__lock:
            for i, subscription in enumerate(subscriptions):
                str_subscription_name = 'resources[#{:d}]'.format(i)
                try:
                    self.__validate_subscription(str_subscription_name, subscription)
                except Exception as e: # pylint: disable=broad-except
                    # Log the item itself: its fields may not have been unpacked
                    # when validation failed.
                    LOGGER.exception('Exception validating {:s}: {:s}'.format(str_subscription_name, str(subscription)))
                    results.append(e) # if validation fails, store the exception
                    continue

                # TODO: format unsubscription
                # TODO: self.__in_subscriptions.put_nowait(unsubscription)
                results.append(True)
        return results

    def GetState(self, blocking=False, terminate : Optional[threading.Event] = None) -> Iterator[Tuple[str, Any]]:
        """Yield the samples available in the output queue.

        Iteration stops when the driver has been asked to terminate or when the
        optional `terminate` event is set.

        Args:
            blocking: when False, stop as soon as the queue is empty; when
                True, keep polling the queue (0.1 s per attempt) until
                termination is signalled.
            terminate: optional event the caller can set to stop the iteration.

        Yields:
            Every item read from the queue that is not None.
        """
        while not self.__terminate.is_set():
            if terminate is not None and terminate.is_set(): break
            try:
                # The timeout only applies when blocking is True.
                sample = self.__out_samples.get(block=blocking, timeout=0.1)
            except queue.Empty:
                if blocking: continue
                return
            if sample is None: continue
            yield sample
