# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging, random, time
from typing import Dict, Optional, Union
from flask import request
from flask.json import jsonify
from flask.wrappers import Response
from flask_restful import Resource
from werkzeug.exceptions import UnsupportedMediaType
from common.proto.context_pb2 import Service, Slice
from common.tools.grpc.ConfigRules import update_config_rule_custom
from common.tools.grpc.Constraints import (
    update_constraint_custom, update_constraint_endpoint_location, update_constraint_endpoint_priority,
    update_constraint_sla_availability)
from common.tools.grpc.EndPointIds import update_endpoint_ids
from common.tools.grpc.Tools import grpc_message_to_json_string
from context.client.ContextClient import ContextClient
from service.client.ServiceClient import ServiceClient
from slice.client.SliceClient import SliceClient
from .schemas.site_network_access import SCHEMA_SITE_NETWORK_ACCESS
from .tools.Authentication import HTTP_AUTH
from .tools.ContextMethods import get_service, get_slice
from .tools.HttpStatusCodes import HTTP_NOCONTENT, HTTP_SERVERERROR
from .tools.Validator import validate_message
from .Constants import (
    BEARER_MAPPINGS, DEFAULT_ADDRESS_FAMILIES, DEFAULT_BGP_AS, DEFAULT_BGP_ROUTE_TARGET, DEFAULT_MTU)

LOGGER = logging.getLogger(__name__)

def process_site_network_access(
    context_client : ContextClient, site_id : str, site_network_access : Dict
) -> Union[Service, Slice]:
    """
    Build the updated Service or Slice for one IETF L2VPN site-network-access entry.

    The VPN named by 'vpn-attachment'/'vpn-id' is retrieved from Context, first as a Slice and,
    if no Slice is found, as a Service. Its endpoint ids, config rules and constraints are then
    updated through the common.tools.grpc update_* helpers, using the bearer mapping, VLAN,
    diversity and availability settings of the entry. The message is returned but NOT submitted;
    the caller must send it to the Service/Slice component.

    Args:
        context_client: client used to retrieve the existing Slice/Service.
        site_id: site identifier; recorded as the region of the endpoint location constraint.
        site_network_access: one item of the 'ietf-l2vpn-svc:site-network-access' list. The caller
            in this module validates the request against SCHEMA_SITE_NETWORK_ACCESS beforehand.

    Returns:
        The retrieved Slice or Service message with the updates applied.

    Raises:
        Exception: if the bearer reference is not in BEARER_MAPPINGS, if the VPN is not found,
            if a diversity constraint has no target group with exactly one entry, or if the
            retrieved object is neither a Service nor a Slice.
    """
    vpn_id = site_network_access['vpn-attachment']['vpn-id']
    encapsulation_type = site_network_access['connection']['encapsulation-type']
    cvlan_id = site_network_access['connection']['tagged-interface'][encapsulation_type]['cvlan-id']

    bearer_reference = site_network_access['bearer']['bearer-reference']

    availability = site_network_access.get('availability', {})
    access_priority : Optional[int] = availability.get('access-priority')
    single_active   : bool = len(availability.get('single-active', [])) > 0
    all_active      : bool = len(availability.get('all-active', [])) > 0

    # Map each diversity constraint type to (target group name, raise_if_differs), where the group
    # name is the first key of 'target' whose value holds exactly one element.
    raise_if_differs = True
    diversity_constraints = {}
    raw_diversity = site_network_access.get('access-diversity', {}).get('constraints', {}).get('constraint', [])
    for constraint in raw_diversity:
        constraint_type = constraint['constraint-type']
        candidate_targets = [name for name, values in constraint['target'].items() if len(values) == 1]
        if len(candidate_targets) == 0:
            # Fail with a descriptive message instead of an opaque IndexError.
            msg = 'Diversity constraint({:s}) has no target with exactly one entry: {:s}'
            raise Exception(msg.format(str(constraint_type), str(constraint['target'])))
        diversity_constraints[constraint_type] = (candidate_targets[0], raise_if_differs)

    mapping = BEARER_MAPPINGS.get(bearer_reference)
    if mapping is None:
        msg = 'Specified Bearer({:s}) is not configured.'
        raise Exception(msg.format(str(bearer_reference)))
    device_uuid,endpoint_uuid,router_id,route_distinguisher,sub_if_index,address_ip,address_prefix = mapping

    # Look the VPN up as a Slice first, falling back to a Service.
    target : Union[Service, Slice, None] = get_slice(context_client, vpn_id)
    if target is None: target = get_service(context_client, vpn_id)
    if target is None: raise Exception('VPN({:s}) not found in database'.format(str(vpn_id)))

    # Select the repeated fields to update; they are passed to the update_* helpers below.
    if isinstance(target, Service):
        endpoint_ids = target.service_endpoint_ids        # pylint: disable=no-member
        config_rules = target.service_config.config_rules # pylint: disable=no-member
        constraints  = target.service_constraints         # pylint: disable=no-member
    elif isinstance(target, Slice):
        endpoint_ids = target.slice_endpoint_ids        # pylint: disable=no-member
        config_rules = target.slice_config.config_rules # pylint: disable=no-member
        constraints  = target.slice_constraints         # pylint: disable=no-member
    else:
        raise Exception('Target({:s}) not supported'.format(str(target.__class__.__name__)))

    endpoint_id = update_endpoint_ids(endpoint_ids, device_uuid, endpoint_uuid)

    # Service-wide settings, filled from the module defaults.
    service_settings_key = '/settings'
    update_config_rule_custom(config_rules, service_settings_key, {
        'mtu'             : (DEFAULT_MTU,              True),
        'address_families': (DEFAULT_ADDRESS_FAMILIES, True),
        'bgp_as'          : (DEFAULT_BGP_AS,           True),
        'bgp_route_target': (DEFAULT_BGP_ROUTE_TARGET, True),
    })

    # Per-endpoint settings from the bearer mapping plus the customer VLAN of this access.
    endpoint_settings_key = '/device[{:s}]/endpoint[{:s}]/settings'.format(device_uuid, endpoint_uuid)
    field_updates = {
        'router_id'          : (router_id,           True),
        'route_distinguisher': (route_distinguisher, True),
        'sub_interface_index': (sub_if_index,        True),
        'vlan_id'            : (cvlan_id,            True),
    }
    if address_ip      is not None: field_updates['address_ip'     ] = (address_ip,      True)
    if address_prefix  is not None: field_updates['address_prefix' ] = (address_prefix,  True)
    update_config_rule_custom(config_rules, endpoint_settings_key, field_updates)

    # Called even when no diversity constraints were requested (with an empty update set).
    update_constraint_custom(constraints, 'diversity', dict(diversity_constraints))

    update_constraint_endpoint_location(constraints, endpoint_id, region=site_id)
    if access_priority is not None: update_constraint_endpoint_priority(constraints, endpoint_id, access_priority)
    if single_active or all_active:
        # assume 1 disjoint path per endpoint/location included in service/slice:
        # the number of disjoint paths is the smallest endpoint count found at any location.
        location_endpoints = {}
        for constraint in constraints:
            if constraint.WhichOneof('constraint') != 'endpoint_location': continue
            str_endpoint_id = grpc_message_to_json_string(constraint.endpoint_location.endpoint_id)
            str_location_id = grpc_message_to_json_string(constraint.endpoint_location.location)
            location_endpoints.setdefault(str_location_id, set()).add(str_endpoint_id)
        num_endpoints_per_location = {len(endpoints) for endpoints in location_endpoints.values()}
        num_disjoint_paths = min(num_endpoints_per_location)
        update_constraint_sla_availability(constraints, num_disjoint_paths, all_active)

    return target

def process_list_site_network_access(
        context_client : ContextClient, service_client : ServiceClient, slice_client : SliceClient, site_id : str,
        request_data : Dict
    ) -> Response:
    """
    Apply every entry of 'ietf-l2vpn-svc:site-network-access' in request_data to its VPN.

    The request is validated against SCHEMA_SITE_NETWORK_ACCESS first. Each entry is then built
    into an updated Service or Slice and submitted through service_client.UpdateService or
    slice_client.UpdateSlice. Entries are handled independently. A failure while building or
    submitting one entry is logged and appended to the error list, and the remaining entries are
    still processed.

    Args:
        context_client: client used to retrieve existing Slices/Services.
        service_client: client used to submit Service updates.
        slice_client: client used to submit Slice updates.
        site_id: identifier of the site the accesses belong to.
        request_data: decoded JSON request body.

    Returns:
        A JSON response with the list of {'error': <message>} entries. The status is HTTP_NOCONTENT
        when the list is empty, and HTTP_SERVERERROR otherwise.
    """
    LOGGER.debug('Request: {:s}'.format(str(request_data)))
    validate_message(SCHEMA_SITE_NETWORK_ACCESS, request_data)

    errors = []
    for site_network_access in request_data['ietf-l2vpn-svc:site-network-access']:
        sna_request : Union[Service, Slice, None] = None
        try:
            # Build inside the try so that a bad entry (unknown bearer, unknown VPN, ...) is reported
            # in the error list instead of aborting the remaining entries.
            sna_request = process_site_network_access(context_client, site_id, site_network_access)
            LOGGER.debug('sna_request = {:s}'.format(grpc_message_to_json_string(sna_request)))
            if isinstance(sna_request, Service):
                sna_reply = service_client.UpdateService(sna_request)
                if sna_reply != sna_request.service_id: # pylint: disable=no-member
                    raise Exception('Service update failed. Wrong Service Id was returned')
            elif isinstance(sna_request, Slice):
                sna_reply = slice_client.UpdateSlice(sna_request)
                if sna_reply != sna_request.slice_id: # pylint: disable=no-member
                    raise Exception('Slice update failed. Wrong Slice Id was returned')
            else:
                raise NotImplementedError('Support for Class({:s}) not implemented'.format(str(type(sna_request))))
        except Exception as e: # pylint: disable=broad-except
            msg = 'Something went wrong Updating VPN {:s}'
            # If the request could not be built, log the raw entry instead of the gRPC message.
            if sna_request is None:
                str_request = str(site_network_access)
            else:
                str_request = grpc_message_to_json_string(sna_request)
            LOGGER.exception(msg.format(str_request))
            errors.append({'error': str(e)})
        # Small random delay between consecutive updates.
        time.sleep(random.random() / 10.0)

    response = jsonify(errors)
    response.status_code = HTTP_NOCONTENT if len(errors) == 0 else HTTP_SERVERERROR
    return response

class L2VPN_SiteNetworkAccesses(Resource):
    """
    REST resource for the site-network-access list of one L2VPN site.

    POST and PUT are handled identically: both require a JSON body and forward it, together with
    the site identifier, to process_list_site_network_access().
    """

    @HTTP_AUTH.login_required
    def post(self, site_id : str):
        """Handle POST by applying the JSON site-network-access list to site_id."""
        return self._apply_site_network_accesses(site_id)

    @HTTP_AUTH.login_required
    def put(self, site_id : str):
        """Handle PUT; same processing as POST."""
        return self._apply_site_network_accesses(site_id)

    def _apply_site_network_accesses(self, site_id : str):
        """Shared POST/PUT body; raises UnsupportedMediaType when the payload is not JSON."""
        if not request.is_json:
            raise UnsupportedMediaType('JSON payload is required')
        LOGGER.debug('Site_Id: {:s}'.format(str(site_id)))
        # Clients are created per request, in the same order as before.
        context_client, service_client, slice_client = ContextClient(), ServiceClient(), SliceClient()
        return process_list_site_network_access(
            context_client, service_client, slice_client, site_id, request.json)
