# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time, random
from ctypes import Union
import json, logging
from typing import Dict
from flask import request
from flask.json import jsonify
from flask.wrappers import Response
from flask_restful import Resource
from werkzeug.exceptions import UnsupportedMediaType
from common.Settings import get_setting
from common.tools.grpc.Tools import grpc_message_to_json_string
from context.client.ContextClient import ContextClient
from context.proto.context_pb2 import ConfigActionEnum, Service, Slice
from service.client.ServiceClient import ServiceClient
from slice.client.SliceClient import SliceClient
from .schemas.site_network_access import SCHEMA_SITE_NETWORK_ACCESS
from .tools.Authentication import HTTP_AUTH
from .tools.ContextMethods import get_service, get_slice
from .tools.HttpStatusCodes import HTTP_NOCONTENT, HTTP_SERVERERROR
from .tools.Validator import validate_message
from .Constants import BEARER_MAPPINGS, DEFAULT_ADDRESS_FAMILIES, DEFAULT_BGP_AS, DEFAULT_BGP_ROUTE_TARGET, DEFAULT_MTU

# Module-level logger, named after this module, used by the request handlers below.
LOGGER = logging.getLogger(__name__)

def _merge_setting(json_settings : Dict, key : str, value, label : str) -> None:
    """Reconcile one stored setting with the value derived from the current request.

    Adds `value` under `key` when the key is missing. Leaves `json_settings` untouched when the
    stored value equals `value`. Raises otherwise.

    Raises:
        Exception: if `json_settings[key]` exists and differs from `value`. The message reports the
            requested value as "Specified" and the already-stored value as "Service".
    """
    if key not in json_settings:                                            # missing, add it
        json_settings[key] = value
    elif json_settings[key] != value:                                       # differs, raise exception
        msg = 'Specified {:s}({:s}) differs from Service {:s}({:s})'
        raise Exception(msg.format(label, str(value), label, str(json_settings[key])))

def process_site_network_access(context_client : ContextClient, site_network_access : Dict) -> Service:
    """Merge one IETF L2VPN site-network-access entry into its target Service or Slice.

    The VPN referenced by 'vpn-attachment'/'vpn-id' is looked up first as a Service, then as a
    Slice. The bearer reference is resolved through BEARER_MAPPINGS to a device endpoint plus its
    routing/addressing parameters. That endpoint is appended to the target's endpoint list unless
    it is already present.

    - For a Slice, the target is returned right after the endpoint update.
    - For a Service, the '/settings' config rule, the per-endpoint settings config rule and the
      'diversity' constraint are also created or merged.

    The returned message is only modified in memory; this function does not write it back.

    Args:
        context_client: client passed to get_service()/get_slice() to retrieve the target.
        site_network_access: one entry of the 'ietf-l2vpn-svc:site-network-access' list.

    Returns:
        The updated Service or Slice message (annotated as Service for historical reasons).

    Raises:
        Exception: if the bearer reference is not in BEARER_MAPPINGS, if the VPN is found neither
            as a Service nor as a Slice, or if a requested setting conflicts with the value
            already stored in the Service.
        KeyError: if one of the mandatory fields read with [] is missing.
    """
    vpn_id = site_network_access['vpn-attachment']['vpn-id']
    cvlan_id = site_network_access['connection']['tagged-interface']['dot1q-vlan-tagged']['cvlan-id']
    bearer_reference = site_network_access['bearer']['bearer-reference']
    availability = site_network_access.get('availability', {})
    access_priority = availability.get('access-priority')
    single_active = availability.get('single-active')
    all_active = availability.get('all-active')
    diversity_constraints = site_network_access.get('access-diversity', {}).get('constraints', {}).get('constraint', [])
    # TODO: manage targets of constraints, right now, only type of constraint is considered
    diversity_constraints = [constraint['constraint-type'] for constraint in diversity_constraints]

    mapping = BEARER_MAPPINGS.get(bearer_reference)
    if mapping is None:
        msg = 'Specified Bearer({:s}) is not configured.'
        raise Exception(msg.format(str(bearer_reference)))
    device_uuid,endpoint_uuid,router_id,route_distinguisher,sub_if_index,address_ip,address_prefix = mapping

    # The VPN may be modelled either as a Service or as a Slice; try Service first.
    target = get_service(context_client, vpn_id)
    if target is None: target = get_slice(context_client, vpn_id)
    if target is None: raise Exception('VPN({:s}) not found in database'.format(str(vpn_id)))

    # pylint: disable=no-member
    endpoint_ids = target.service_endpoint_ids if isinstance(target, Service) else target.slice_endpoint_ids

    # Append the bearer's endpoint unless the target already lists it.
    endpoint_present = any(
        endpoint_id.device_id.device_uuid.uuid == device_uuid and endpoint_id.endpoint_uuid.uuid == endpoint_uuid
        for endpoint_id in endpoint_ids
    )
    if not endpoint_present:
        endpoint_id = endpoint_ids.add()
        endpoint_id.device_id.device_uuid.uuid = device_uuid
        endpoint_id.endpoint_uuid.uuid = endpoint_uuid

    # For a Slice only the endpoint list is updated.
    if isinstance(target, Slice): return target

    config_rules = target.service_config.config_rules

    # Service-wide settings: (json key, requested value, label used in conflict messages).
    service_settings = [
        ('mtu',              DEFAULT_MTU,              'MTU'),
        ('address_families', DEFAULT_ADDRESS_FAMILIES, 'AddressFamilies'),
        ('bgp_as',           DEFAULT_BGP_AS,           'BgpAs'),
        ('bgp_route_target', DEFAULT_BGP_ROUTE_TARGET, 'BgpRouteTarget'),
    ]
    for config_rule in config_rules:
        if config_rule.resource_key != '/settings': continue
        json_settings = json.loads(config_rule.resource_value)
        for key, value, label in service_settings:
            _merge_setting(json_settings, key, value, label)
        config_rule.resource_value = json.dumps(json_settings, sort_keys=True)
        break
    else:
        # not found, add it
        config_rule = config_rules.add()
        config_rule.action = ConfigActionEnum.CONFIGACTION_SET
        config_rule.resource_key = '/settings'
        config_rule.resource_value = json.dumps(
            {key: value for key, value, _ in service_settings}, sort_keys=True)

    # Per-endpoint settings. Address fields are optional in the bearer mapping: when unset (None)
    # they are neither merged into nor checked against an existing rule.
    mandatory_endpoint_settings = [
        ('router_id',           router_id,           'RouterId'),
        ('route_distinguisher', route_distinguisher, 'RouteDistinguisher'),
        ('sub_interface_index', sub_if_index,        'SubInterfaceIndex'),
        ('vlan_id',             cvlan_id,            'VLANId'),
    ]
    optional_endpoint_settings = [
        ('address_ip',          address_ip,          'AddressIP'),
        ('address_prefix',      address_prefix,      'AddressPrefix'),
    ]
    endpoint_settings_key = '/device[{:s}]/endpoint[{:s}]/settings'.format(device_uuid, endpoint_uuid)
    for config_rule in config_rules:
        if config_rule.resource_key != endpoint_settings_key: continue
        json_settings = json.loads(config_rule.resource_value)
        for key, value, label in mandatory_endpoint_settings:
            _merge_setting(json_settings, key, value, label)
        for key, value, label in optional_endpoint_settings:
            if value is not None: _merge_setting(json_settings, key, value, label)
        # NOTE(review): access_priority / access_active from this request are only recorded when the
        # rule is created below, not merged into an existing rule; confirm this is intended.
        config_rule.resource_value = json.dumps(json_settings, sort_keys=True)
        break
    else:
        # not found, add it (optional address fields are stored even when None)
        config_rule = config_rules.add()
        config_rule.action = ConfigActionEnum.CONFIGACTION_SET
        config_rule.resource_key = endpoint_settings_key
        resource_value = {
            key: value for key, value, _ in mandatory_endpoint_settings + optional_endpoint_settings
        }
        if access_priority is not None: resource_value['access_priority'] = access_priority
        if single_active is not None and len(single_active) > 0: resource_value['access_active'] = 'single'
        # 'all' takes precedence over 'single' when both are given.
        if all_active is not None and len(all_active) > 0: resource_value['access_active'] = 'all'
        config_rule.resource_value = json.dumps(resource_value, sort_keys=True)

    # Merge requested diversity constraint types into the existing 'diversity' constraint (a sorted,
    # de-duplicated JSON list), or create it if absent. Nothing is done when none are requested.
    if len(diversity_constraints) > 0:
        for constraint in target.service_constraints:
            if constraint.constraint_type != 'diversity': continue
            constraint_value = set(json.loads(constraint.constraint_value))
            constraint_value.update(diversity_constraints)
            constraint.constraint_value = json.dumps(sorted(constraint_value), sort_keys=True)
            break
        else:
            # not found, add it
            constraint = target.service_constraints.add()
            constraint.constraint_type = 'diversity'
            constraint.constraint_value = json.dumps(sorted(diversity_constraints), sort_keys=True)

    return target

def process_list_site_network_access(
        context_client : ContextClient, service_client : ServiceClient, slice_client : SliceClient,
        request_data : Dict
    ) -> Response:
    """Process every site-network-access entry of a request and build the HTTP response.

    The payload is validated against SCHEMA_SITE_NETWORK_ACCESS. Each entry under
    'ietf-l2vpn-svc:site-network-access' is then merged into its Service/Slice through
    process_site_network_access(), and the merged message is logged.

    NOTE(review): the UpdateService/UpdateSlice calls below are commented out. As a result:
    - `service_client` and `slice_client` are unused;
    - `errors` is never populated;
    - the response is always an empty JSON list with HTTP_NOCONTENT.

    Exceptions raised while merging an entry are not caught here.
    """
    LOGGER.debug('Request: {:s}'.format(str(request_data)))
    validate_message(SCHEMA_SITE_NETWORK_ACCESS, request_data)

    errors = []
    entries = request_data['ietf-l2vpn-svc:site-network-access']
    for entry in entries:
        update_request = process_site_network_access(context_client, entry)
        LOGGER.debug('sna_request = {:s}'.format(grpc_message_to_json_string(update_request)))
        #try:
        #    if isinstance(update_request, Service):
        #        update_reply = service_client.UpdateService(update_request)
        #        if update_reply != update_request.service_id: # pylint: disable=no-member
        #            raise Exception('Service update failed. Wrong Service Id was returned')
        #    elif isinstance(update_request, Slice):
        #        update_reply = slice_client.UpdateSlice(update_request)
        #        if update_reply != update_request.slice_id: # pylint: disable=no-member
        #            raise Exception('Slice update failed. Wrong Slice Id was returned')
        #    else:
        #        raise NotImplementedError('Support for Class({:s}) not implemented'.format(str(type(update_request))))
        #except Exception as e: # pylint: disable=broad-except
        #    msg = 'Something went wrong Updating Service {:s}'
        #    LOGGER.exception(msg.format(grpc_message_to_json_string(update_request)))
        #    errors.append({'error': str(e)})
        # Short random pause (0 to 100 ms) between consecutive entries.
        time.sleep(random.random() / 10.0)

    status_code = HTTP_SERVERERROR if errors else HTTP_NOCONTENT
    response = jsonify(errors)
    response.status_code = status_code
    return response

class L2VPN_SiteNetworkAccesses(Resource):
    """REST resource for the site-network-access list of an L2VPN site.

    POST and PUT are handled identically. Both require authentication via
    HTTP_AUTH.login_required and a JSON body. The body is delegated to
    process_list_site_network_access() with newly created Context/Service/Slice clients.
    """

    @staticmethod
    def _handle(site_id : str) -> Response:
        # Shared POST/PUT logic; authentication is enforced by the decorated public methods.
        if not request.is_json: raise UnsupportedMediaType('JSON payload is required')
        LOGGER.debug('Site_Id: {:s}'.format(str(site_id)))
        return process_list_site_network_access(
            ContextClient(), ServiceClient(), SliceClient(), request.json)

    @HTTP_AUTH.login_required
    def post(self, site_id : str):
        return self._handle(site_id)

    @HTTP_AUTH.login_required
    def put(self, site_id : str):
        return self._handle(site_id)
