# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import anytree, logging, pytz, queue, re, threading
import lxml.etree as ET
from datetime import datetime, timedelta
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.job import Job
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.schedulers.background import BackgroundScheduler
from netconf_client.connect import connect_ssh
from netconf_client.ncclient import Manager
from common.type_checkers.Checkers import chk_length, chk_string, chk_type, chk_float
from device.service.driver_api.Exceptions import UnsupportedResourceKeyException
from device.service.driver_api._Driver import _Driver
from device.service.driver_api.AnyTreeTools import TreeNode, dump_subtree, get_subnode, set_subnode_value
from device.service.drivers.openconfig.Tools import xml_pretty_print, xml_to_dict, xml_to_file
from device.service.drivers.openconfig.templates import ALL_RESOURCE_KEYS, compose_config, get_filter, parse

DEBUG_MODE = False
#logging.getLogger('ncclient.transport.ssh').setLevel(logging.DEBUG if DEBUG_MODE else logging.WARNING)
logging.getLogger('apscheduler.executors.default').setLevel(logging.INFO if DEBUG_MODE else logging.ERROR)
logging.getLogger('apscheduler.scheduler').setLevel(logging.INFO if DEBUG_MODE else logging.ERROR)
logging.getLogger('monitoring-client').setLevel(logging.INFO if DEBUG_MODE else logging.ERROR)

LOGGER = logging.getLogger(__name__)

RE_GET_ENDPOINT_FROM_INTERFACE_KEY = re.compile(r'.*interface\[([^\]]+)\].*')
RE_GET_ENDPOINT_FROM_INTERFACE_XPATH = re.compile(r".*interface\[oci\:name\='([^\]]+)'\].*")

# Collecting samples through NetConf is very slow and each request retrieves all the data.
# Populate a cache periodically (when the first interface is interrogated).
# Evict data after SAMPLE_EVICTION_SECONDS, when it is considered outdated.

SAMPLE_EVICTION_SECONDS = 30.0 # seconds
SAMPLE_RESOURCE_KEY = 'interfaces/interface/state/counters'

class SamplesCache:
    def __init__(self) -> None:
        self.__lock = threading.Lock()
        self.__timestamp = None
        self.__samples = {}

    def _refresh_samples(self, netconf_manager : Manager) -> None:
        with self.__lock:
            try:
                now = datetime.timestamp(datetime.utcnow())
                if self.__timestamp is not None and (now - self.__timestamp) < SAMPLE_EVICTION_SECONDS: return
                str_filter = get_filter(SAMPLE_RESOURCE_KEY)
                xml_data = netconf_manager.get(filter=str_filter).data_ele
                interface_samples = parse(SAMPLE_RESOURCE_KEY, xml_data)
                for interface,samples in interface_samples:
                    match = RE_GET_ENDPOINT_FROM_INTERFACE_KEY.match(interface)
                    if match is None: continue
                    interface = match.group(1)
                    self.__samples[interface] = samples
                self.__timestamp = now
            except: # pylint: disable=bare-except
                LOGGER.exception('Error collecting samples')

    def get(self, resource_key : str, netconf_manager : Manager) -> Tuple[float, Dict]:
        self._refresh_samples(netconf_manager)
        match = RE_GET_ENDPOINT_FROM_INTERFACE_XPATH.match(resource_key)
        with self.__lock:
            if match is None: return self.__timestamp, {}
            interface = match.group(1)
            return self.__timestamp, self.__samples.get(interface, {})

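# Collect a single counter sample for resource_key from the shared SamplesCache and enqueue it
# into out_samples as a (timestamp, resource_key, value) tuple; intended to run as a scheduled job.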
def do_sampling(
    netconf_manager : Manager, samples_cache : SamplesCache, resource_key : str, out_samples : queue.Queue
) -> None:
    try:
        timestamp, samples = samples_cache.get(resource_key, netconf_manager)
        counter_name = resource_key.split('/')[-1].split(':')[-1]
        value = samples.get(counter_name)
        if value is None:
            LOGGER.warning('[do_sampling] value not found for {:s}'.format(resource_key))
            return
        sample = (timestamp, resource_key, value)
        out_samples.put_nowait(sample)
    except: # pylint: disable=bare-except
        LOGGER.exception('Error retrieving samples')


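# Driver for OpenConfig-enabled devices managed through NetConf. Configuration is read and edited
# by composing/parsing XML templates, and state subscriptions are emulated with a background
# APScheduler that periodically samples interface counters through the SamplesCache.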
class OpenConfigDriver(_Driver):
    def __init__(self, address : str, port : int, **settings) -> None: # pylint: disable=super-init-not-called
        self.__address = address
        self.__port = int(port)
        self.__settings = settings
        self.__lock = threading.Lock()
        #self.__initial = TreeNode('.')
        #self.__running = TreeNode('.')
        self.__subscriptions = TreeNode('.')
        self.__started = threading.Event()
        self.__terminate = threading.Event()
        self.__netconf_manager : Manager = None
        self.__scheduler = BackgroundScheduler(daemon=True) # scheduler used to emulate sampling events
        self.__scheduler.configure(
            jobstores = {'default': MemoryJobStore()},
            executors = {'default': ThreadPoolExecutor(max_workers=1)},
            job_defaults = {'coalesce': False, 'max_instances': 3},
            timezone=pytz.utc)
        self.__out_samples = queue.Queue()
        self.__samples_cache = SamplesCache()

    def Connect(self) -> bool:
        with self.__lock:
            if self.__started.is_set(): return True
            username = self.__settings.get('username')
            password = self.__settings.get('password')
            timeout = int(self.__settings.get('timeout', 120))
            session = connect_ssh(
                host=self.__address, port=self.__port, username=username, password=password)
            self.__netconf_manager = Manager(session, timeout=timeout)
            self.__netconf_manager.set_logger_level(logging.DEBUG if DEBUG_MODE else logging.WARNING)
            # Connect triggers activation of sampling events that will be scheduled based on subscriptions
            self.__scheduler.start()
            self.__started.set()
            return True

    def Disconnect(self) -> bool:
        with self.__lock:
            # Trigger termination of loops and processes
            self.__terminate.set()
            # If not started, assume it is already disconnected
            if not self.__started.is_set(): return True
            # Disconnect triggers deactivation of sampling events
            self.__scheduler.shutdown()
            self.__netconf_manager.close_session()
            return True

    def GetInitialConfig(self) -> List[Tuple[str, Any]]:
        with self.__lock:
            return []

    def GetConfig(self, resource_keys : List[str] = []) -> List[Tuple[str, Union[Any, None, Exception]]]:
        chk_type('resources', resource_keys, list)
        results = []
        with self.__lock:
            if len(resource_keys) == 0: resource_keys = ALL_RESOURCE_KEYS
            for i,resource_key in enumerate(resource_keys):
                str_resource_name = 'resource_key[#{:d}]'.format(i)
                try:
                    chk_string(str_resource_name, resource_key, allow_empty=False)
                    str_filter = get_filter(resource_key)
                    if str_filter is None: str_filter = resource_key
                    xml_data = self.__netconf_manager.get(filter=str_filter).data_ele
                    if isinstance(xml_data, Exception): raise xml_data
                    results.extend(parse(resource_key, xml_data))
                except Exception as e: # pylint: disable=broad-except
                    LOGGER.exception('Exception retrieving {:s}: {:s}'.format(str_resource_name, str(resource_key)))
                    results.append((resource_key, e)) # if validation fails, store the exception
        return results

    def SetConfig(self, resources : List[Tuple[str, Any]]) -> List[Union[bool, Exception]]:
        chk_type('resources', resources, list)
        if len(resources) == 0: return []
        results = []
        LOGGER.info('[SetConfig] resources = {:s}'.format(str(resources)))
        with self.__lock:
            for i,resource in enumerate(resources):
                str_resource_name = 'resources[#{:d}]'.format(i)
                try:
                    LOGGER.info('[SetConfig] resource = {:s}'.format(str(resource)))
                    chk_type(str_resource_name, resource, (list, tuple))
                    chk_length(str_resource_name, resource, min_length=2, max_length=2)
                    resource_key,resource_value = resource
                    chk_string(str_resource_name + '.key', resource_key, allow_empty=False)
                    str_config_message = compose_config(resource_key, resource_value)
                    if str_config_message is None: raise UnsupportedResourceKeyException(resource_key)
                    LOGGER.info('[SetConfig] str_config_message = {:s}'.format(str(str_config_message)))
                    self.__netconf_manager.edit_config(str_config_message, target='running')
                    results.append(True)
                except Exception as e: # pylint: disable=broad-except
                    LOGGER.exception('Exception setting {:s}: {:s}'.format(str_resource_name, str(resource)))
                    results.append(e) # if validation fails, store the exception
        return results

    def DeleteConfig(self, resources : List[Tuple[str, Any]]) -> List[Union[bool, Exception]]:
        chk_type('resources', resources, list)
        if len(resources) == 0: return []
        results = []
        LOGGER.info('[DeleteConfig] resources = {:s}'.format(str(resources)))
        with self.__lock:
            for i,resource in enumerate(resources):
                str_resource_name = 'resources[#{:d}]'.format(i)
                try:
                    LOGGER.info('[DeleteConfig] resource = {:s}'.format(str(resource)))
                    chk_type(str_resource_name, resource, (list, tuple))
                    chk_length(str_resource_name, resource, min_length=2, max_length=2)
                    resource_key,resource_value = resource
                    chk_string(str_resource_name + '.key', resource_key, allow_empty=False)
                    str_config_message = compose_config(resource_key, resource_value, delete=True)
                    if str_config_message is None: raise UnsupportedResourceKeyException(resource_key)
                    LOGGER.info('[DeleteConfig] str_config_message = {:s}'.format(str(str_config_message)))
                    self.__netconf_manager.edit_config(str_config_message, target='running')
                    results.append(True)
                except Exception as e: # pylint: disable=broad-except
                    LOGGER.exception('Exception deleting {:s}: {:s}'.format(str_resource_name, str(resource)))
                    results.append(e) # if validation fails, store the exception
        return results

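    # Each subscription schedules a periodic do_sampling job; the Job handle is stored in the
    # __subscriptions tree under <resource_path>/<duration>:<interval> so it can be removed later.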
    def SubscribeState(self, subscriptions : List[Tuple[str, float, float]]) -> List[Union[bool, Exception]]:
        chk_type('subscriptions', subscriptions, list)
        if len(subscriptions) == 0: return []
        results = []
        resolver = anytree.Resolver(pathattr='name')
        with self.__lock:
            for i,subscription in enumerate(subscriptions):
                str_subscription_name = 'subscriptions[#{:d}]'.format(i)
                try:
                    chk_type(str_subscription_name, subscription, (list, tuple))
                    chk_length(str_subscription_name, subscription, min_length=3, max_length=3)
                    resource_key,sampling_duration,sampling_interval = subscription
                    chk_string(str_subscription_name + '.resource_key', resource_key, allow_empty=False)
                    resource_path = resource_key.split('/')
                    chk_float(str_subscription_name + '.sampling_duration', sampling_duration, min_value=0)
                    chk_float(str_subscription_name + '.sampling_interval', sampling_interval, min_value=0)
                except Exception as e: # pylint: disable=broad-except
                    LOGGER.exception('Exception validating {:s}: {:s}'.format(str_subscription_name, str(subscription)))
                    results.append(e) # if validation fails, store the exception
                    continue

                start_date,end_date = None,None
                if sampling_duration >= 1.e-12:
                    start_date = datetime.utcnow()
                    end_date = start_date + timedelta(seconds=sampling_duration)

                job_id = 'k={:s}/d={:f}/i={:f}'.format(resource_key, sampling_duration, sampling_interval)
                job = self.__scheduler.add_job(
                    do_sampling, args=(self.__netconf_manager, self.__samples_cache, resource_key, self.__out_samples),
                    kwargs={}, id=job_id, trigger='interval', seconds=sampling_interval,
                    start_date=start_date, end_date=end_date, timezone=pytz.utc)

                subscription_path = resource_path + ['{:.3f}:{:.3f}'.format(sampling_duration, sampling_interval)]
                set_subnode_value(resolver, self.__subscriptions, subscription_path, job)
                results.append(True)
        return results

    def UnsubscribeState(self, subscriptions : List[Tuple[str, float, float]]) -> List[Union[bool, Exception]]:
        chk_type('subscriptions', subscriptions, list)
        if len(subscriptions) == 0: return []
        results = []
        resolver = anytree.Resolver(pathattr='name')
        with self.__lock:
            for i,resource in enumerate(subscriptions):
                str_subscription_name = 'subscriptions[#{:d}]'.format(i)
                try:
                    chk_type(str_subscription_name, resource, (list, tuple))
                    chk_length(str_subscription_name, resource, min_length=3, max_length=3)
                    resource_key,sampling_duration,sampling_interval = resource
                    chk_string(str_subscription_name + '.resource_key', resource_key, allow_empty=False)
                    resource_path = resource_key.split('/')
                    chk_float(str_subscription_name + '.sampling_duration', sampling_duration, min_value=0)
                    chk_float(str_subscription_name + '.sampling_interval', sampling_interval, min_value=0)
                except Exception as e: # pylint: disable=broad-except
                    LOGGER.exception('Exception validating {:s}: {:s}'.format(str_subscription_name, str(resource)))
                    results.append(e) # if validation fails, store the exception
                    continue

                subscription_path = resource_path + ['{:.3f}:{:.3f}'.format(sampling_duration, sampling_interval)]
                subscription_node = get_subnode(resolver, self.__subscriptions, subscription_path)

                # if not found, subscription_node is None
                if subscription_node is None:
                    results.append(False)
                    continue

                job : Job = getattr(subscription_node, 'value', None)
                if job is None or not isinstance(job, Job):
                    raise Exception('Malformed subscription node or wrong resource key: {:s}'.format(str(resource)))
                job.remove()

                parent = subscription_node.parent
                children = list(parent.children)
                children.remove(subscription_node)
                parent.children = tuple(children)

                results.append(True)
        return results

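    # Generator that yields the samples queued by the scheduled sampling jobs, until either the
    # driver is terminated or the (optional) terminate event provided by the caller is set.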
    def GetState(self, blocking=False, terminate : Optional[threading.Event] = None) -> Iterator[Tuple[str, Any]]:
        while True:
            if self.__terminate.is_set(): break
            if terminate is not None and terminate.is_set(): break
            try:
                sample = self.__out_samples.get(block=blocking, timeout=0.1)
            except queue.Empty:
                if blocking: continue
                return
            if sample is None: continue
            yield sample
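

# Minimal usage sketch (illustrative only, not part of the original driver). It exercises the
# public driver API defined above (Connect, GetConfig, SubscribeState, GetState, UnsubscribeState,
# Disconnect). The address, port, credentials and the monitored counter key are placeholders;
# the key format is an assumption derived from the regular expressions at the top of this file.
if __name__ == '__main__':
    import time
    logging.basicConfig(level=logging.INFO)
    driver = OpenConfigDriver('10.0.0.1', 830, username='admin', password='admin', timeout=120)
    if driver.Connect():
        # Retrieve the running configuration for all supported resource keys.
        for resource_key, resource_value in driver.GetConfig():
            LOGGER.info('config: {:s} = {:s}'.format(str(resource_key), str(resource_value)))
        # Sample an interface counter every 5 seconds during 30 seconds.
        counter_key = "interfaces/interface[oci:name='eth1']/state/counters/out-pkts"
        subscription = (counter_key, 30.0, 5.0)
        driver.SubscribeState([subscription])
        time.sleep(10.0)  # let a few sampling jobs run before draining the queue
        for sample in driver.GetState(blocking=False):
            LOGGER.info('sample: {:s}'.format(str(sample)))
        driver.UnsubscribeState([subscription])
        driver.Disconnect()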