# Copyright 2022-2023 ETSI TeraFlowSDN - TFS OSG (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import json
import grpc
from collections import defaultdict
from flask import current_app, redirect, render_template, Blueprint, flash, session, url_for, request
from common.proto.context_pb2 import (
    IsolationLevelEnum, Service, ServiceId, ServiceTypeEnum, ServiceStatusEnum, Connection, Empty, DeviceDriverEnum,
    ConfigActionEnum, Device, DeviceList)
from common.tools.context_queries.Context import get_context
from common.tools.context_queries.Topology import get_topology
from common.tools.context_queries.EndPoint import get_endpoint_names
from common.tools.context_queries.Service import get_service_by_uuid
from common.tools.object_factory.ConfigRule import json_config_rule_set
from common.tools.object_factory.Context import json_context_id
from common.tools.object_factory.Topology import json_topology_id
from context.client.ContextClient import ContextClient
from service.client.ServiceClient import ServiceClient
from typing import Optional, Set

# Flask blueprint grouping the service-related WebUI pages under the '/service' URL prefix.
service = Blueprint('service', __name__, url_prefix='/service')

# Module-level client wrappers shared by every route of this blueprint; the routes call
# connect()/close() on them around their work (directly or through connected_client()).
# NOTE(review): these objects are shared across concurrent requests without any locking;
# confirm the client wrappers tolerate that.
context_client = ContextClient()
service_client = ServiceClient()

@contextmanager
def connected_client(c):
    """Keep *c* connected for the duration of a ``with`` block, then close it.

    The ``connect()`` call sits inside the protected region, so ``close()`` is
    invoked on normal exit, when the ``with`` body raises, and also when
    connecting itself fails.

    :param c: Any object exposing ``connect()`` and ``close()``.
    :yields: The very same object *c*, after ``connect()`` returned.
    """
    client = c
    try:
        client.connect()
        yield client
    finally:
        # Runs on every exit path of the with-block.
        client.close()

# The context client must already be connected when calling this function.
def get_device_drivers_in_use(topology_uuid: str, context_uuid: str) -> Set[str]:
    """Return the names of the device drivers declared by the devices of a topology.

    :param topology_uuid: Identifier of the topology to inspect.
    :param context_uuid: Identifier of the context the topology belongs to.
    :return: Set of ``DeviceDriverEnum`` names (e.g. ``"DEVICEDRIVER_XR"``) declared by at
             least one device of the topology; an empty set when the topology is not found.
    """
    grpc_topology = get_topology(context_client, topology_uuid, context_uuid=context_uuid, rw_copy=False)
    if grpc_topology is None:
        # get_topology() returns None for a missing topology; without this guard the
        # attribute access below raises AttributeError and crashes the calling page.
        return set()

    topo_device_uuids = {device_id.device_uuid.uuid for device_id in grpc_topology.device_ids}
    # ListDevices returns every device known to the context service; keep only the
    # devices that belong to this topology.
    grpc_devices: DeviceList = context_client.ListDevices(Empty())
    return {
        DeviceDriverEnum.Name(driver)
        for device in grpc_devices.devices
        if device.device_id.device_uuid.uuid in topo_device_uuids
        for driver in device.device_drivers
    }

@service.get('/')
def home():
    """Render the list of services of the context selected in the session.

    Redirects to the main page (with a warning) when no context/topology is selected.
    When the context, or its services, cannot be found, an empty list is rendered.
    """
    if 'context_uuid' not in session or 'topology_uuid' not in session:
        flash("Please select a context!", "warning")
        return redirect(url_for("main.home"))
    context_uuid = session['context_uuid']
    topology_uuid = session['topology_uuid']

    # Defaults used whenever the context or its services cannot be retrieved. They are
    # set up front so every code path reaches render_template with all names bound.
    services, device_names, endpoints_data = list(), dict(), dict()
    active_drivers = set()

    # connected_client() guarantees the client is closed even when an RpcError is re-raised.
    with connected_client(context_client):
        context_obj = get_context(context_client, context_uuid, rw_copy=False)
        if context_obj is None:
            flash('Context({:s}) not found'.format(str(context_uuid)), 'danger')
        else:
            try:
                services = context_client.ListServices(context_obj.context_id).services
                active_drivers = get_device_drivers_in_use(topology_uuid, context_uuid)
            except grpc.RpcError as e:
                # Only a "context not found" error (e.g. context removed meanwhile) is
                # tolerated; anything else propagates.
                if e.code() != grpc.StatusCode.NOT_FOUND: raise
                if e.details() != 'Context({:s}) not found'.format(context_uuid): raise
                services = list()
                active_drivers = set()
            else:
                endpoint_ids = [
                    endpoint_id
                    for service_ in services
                    for endpoint_id in service_.service_endpoint_ids
                ]
                device_names, endpoints_data = get_endpoint_names(context_client, endpoint_ids)

    return render_template(
        'service/home.html', services=services, device_names=device_names, endpoints_data=endpoints_data,
        ste=ServiceTypeEnum, sse=ServiceStatusEnum, active_drivers=active_drivers)


@service.route('add', methods=['GET', 'POST'])
def add():
    """Placeholder for generic service creation; not implemented.

    Queues an error flash message and then raises NotImplementedError, which
    propagates to Flask as an unhandled exception.
    """
    flash('Add service route called', 'danger')
    # TODO: render a creation form (previously sketched as 'service/home.html').
    raise NotImplementedError

def get_hub_module_name(dev: Device) -> Optional[str]:
    """Return the hub module name configured on a device, if any.

    Scans the device's SET custom config rules for one whose resource key is
    ``"_connect/settings"`` and whose JSON resource value is an object with a
    ``"hub_module_name"`` attribute.

    :param dev: Device whose ``device_config.config_rules`` are inspected.
    :return: The first ``hub_module_name`` value found, or None when no rule carries one.
    """
    for cr in dev.device_config.config_rules:
        if cr.action != ConfigActionEnum.CONFIGACTION_SET:
            continue
        # Test field presence explicitly: reading cr.custom on a rule that does not set it
        # yields a default sub-message rather than a falsy value.
        if not cr.HasField('custom') or cr.custom.resource_key != "_connect/settings":
            continue
        try:
            cr_dict = json.loads(cr.custom.resource_value)
        except json.JSONDecodeError:
            # Malformed settings are skipped on a best-effort basis; keep scanning.
            continue
        # The resource value may be any JSON value; indexing a string or list with a
        # str key would raise TypeError, so only JSON objects are considered.
        if isinstance(cr_dict, dict) and "hub_module_name" in cr_dict:
            return cr_dict["hub_module_name"]
    return None

@service.route('add-xr', methods=['GET', 'POST'])
def add_xr():
    """Show the XR connectivity-service form (GET) or create the service (POST).

    Both methods first load the following from the context/topology selected in the session:
    the topology devices that use the XR driver; each device's endpoints, split into hub
    interfaces (prefixed by the device's hub module name plus ``"|"``) and leaf
    interfaces; and the endpoints already used by TAPI connectivity services.

    On POST, the form fields ``service_name``, ``constellation``, ``hubif`` and ``leafif``
    are validated. On success, a TAPI connectivity service is created: CreateService is
    called without endpoints, then UpdateService adds the endpoints and the TAPI settings
    config rule. The user is then redirected to the service list. Validation or RPC
    failures flash an error and redirect back to this form.
    """
    ### FIXME: copypaste
    if 'context_uuid' not in session or 'topology_uuid' not in session:
        flash("Please select a context!", "warning")
        return redirect(url_for("main.home"))

    context_uuid = session['context_uuid']
    topology_uuid = session['topology_uuid']

    # connected_client() closes the context client on every exit path, including the
    # early redirects below and any RPC exception.
    with connected_client(context_client):
        grpc_topology = get_topology(context_client, topology_uuid, context_uuid=context_uuid, rw_copy=False)
        if grpc_topology is None:
            flash('Context({:s})/Topology({:s}) not found'.format(str(context_uuid), str(topology_uuid)), 'danger')
            return redirect(url_for("main.home"))

        topo_device_uuids = {device_id.device_uuid.uuid for device_id in grpc_topology.device_ids}
        grpc_devices = context_client.ListDevices(Empty())
        devices = [
            device for device in grpc_devices.devices
            if device.device_id.device_uuid.uuid in topo_device_uuids
            and DeviceDriverEnum.DEVICEDRIVER_XR in device.device_drivers
        ]
        devices.sort(key=lambda dev: dev.name)

        hub_interfaces_by_device = defaultdict(list)
        leaf_interfaces_by_device = defaultdict(list)
        constellation_name_to_uuid = {}
        dev_ep_to_uuid = {}     # (device name, endpoint name) -> endpoint uuid
        ep_uuid_to_name = {}    # endpoint uuid -> (device name, endpoint name)
        for d in devices:
            constellation_name_to_uuid[d.name] = d.device_id.device_uuid.uuid
            hm_name = get_hub_module_name(d)
            if hm_name is None:
                # Without a hub module name the endpoints cannot be classified; skip them.
                continue
            hm_if_prefix = hm_name + "|"
            for ep in d.device_endpoints:
                ep_uuid = ep.endpoint_id.endpoint_uuid.uuid
                dev_ep_to_uuid[(d.name, ep.name)] = ep_uuid
                ep_uuid_to_name[ep_uuid] = (d.name, ep.name)
                if ep.name.startswith(hm_if_prefix):
                    hub_interfaces_by_device[d.name].append(ep.name)
                else:
                    leaf_interfaces_by_device[d.name].append(ep.name)
            hub_interfaces_by_device[d.name].sort()
            leaf_interfaces_by_device[d.name].sort()

        # Find out which endpoints are already used so that they can be disabled
        # in the create screen.
        context_obj = get_context(context_client, context_uuid, rw_copy=False)
        if context_obj is None:
            flash('Context({:s}) not found'.format(str(context_uuid)), 'danger')
            # Redirecting to request.url would hit this same branch again on the follow-up
            # GET (a redirect loop); send the user back to context selection instead.
            return redirect(url_for("main.home"))

        services = context_client.ListServices(context_obj.context_id)
        ep_used_by = {}  # "<endpoint name>@<device name>" -> name of the service using it
        for existing_service in services.services:
            if existing_service.service_type != ServiceTypeEnum.SERVICETYPE_TAPI_CONNECTIVITY_SERVICE:
                continue
            for ep in existing_service.service_endpoint_ids:
                ep_uuid = ep.endpoint_uuid.uuid
                if ep_uuid in ep_uuid_to_name:
                    dev_name, ep_name = ep_uuid_to_name[ep_uuid]
                    ep_used_by[f"{ep_name}@{dev_name}"] = existing_service.name

    if request.method != 'POST':
        return render_template('service/add-xr.html', devices=devices, hub_if=hub_interfaces_by_device, leaf_if=leaf_interfaces_by_device, ep_used_by=ep_used_by)

    # --- POST: validate the submitted form (all problems are flashed at once) ---
    service_name = request.form["service_name"]
    if service_name == "":
        flash("Service name must be specified", 'danger')

    constellation = request.form["constellation"]
    constellation_uuid = constellation_name_to_uuid.get(constellation, None)
    if constellation_uuid is None:
        flash(f"Invalid constellation \"{constellation}\"", 'danger')

    hub_if = request.form["hubif"]
    hub_if_uuid = dev_ep_to_uuid.get((constellation, hub_if), None)
    if hub_if_uuid is None:
        flash(f"Invalid hub interface \"{hub_if}\"", 'danger')

    leaf_if = request.form["leafif"]
    leaf_if_uuid = dev_ep_to_uuid.get((constellation, leaf_if), None)
    if leaf_if_uuid is None:
        flash(f"Invalid leaf interface \"{leaf_if}\"", 'danger')

    if service_name == "" or constellation_uuid is None or hub_if_uuid is None or leaf_if_uuid is None:
        return redirect(request.url)

    # --- Build the service request ---
    json_context_uuid = json_context_id(context_uuid)
    # NOTE(review): the endpoints reference topology "admin" rather than the session's
    # topology_uuid; confirm this is intended.
    sr = {
        "name": service_name,
        "service_id": {
             "context_id": {"context_uuid": {"uuid": context_uuid}},
             "service_uuid": {"uuid": service_name}
        },
        'service_type'        : ServiceTypeEnum.SERVICETYPE_TAPI_CONNECTIVITY_SERVICE,
        "service_endpoint_ids": [
            {'device_id': {'device_uuid': {'uuid': constellation_uuid}}, 'endpoint_uuid': {'uuid': hub_if_uuid}, 'topology_id': json_topology_id("admin", context_id=json_context_uuid)},
            {'device_id': {'device_uuid': {'uuid': constellation_uuid}}, 'endpoint_uuid': {'uuid': leaf_if_uuid}, 'topology_id': json_topology_id("admin", context_id=json_context_uuid)}
        ],
        'service_status'      : {'service_status': ServiceStatusEnum.SERVICESTATUS_PLANNED},
        'service_constraints' : [],
    }

    json_tapi_settings = {
        'capacity_value'  : 50.0,
        'capacity_unit'   : 'GHz',
        'layer_proto_name': 'PHOTONIC_MEDIA',
        'layer_proto_qual': 'tapi-photonic-media:PHOTONIC_LAYER_QUALIFIER_NMC',
        'direction'       : 'UNIDIRECTIONAL',
    }
    config_rule = json_config_rule_set('/settings', json_tapi_settings)

    # --- Create (without endpoints), then update with endpoints and configuration ---
    with connected_client(service_client) as sc:
        endpoints, sr['service_endpoint_ids'] = sr['service_endpoint_ids'], []
        try:
            create_response = sc.CreateService(Service(**sr))
        except Exception as e:
            flash(f'Failure to create service {service_name}, exception {str(e)}', 'danger')
            return redirect(request.url)

        sr['service_endpoint_ids'] = endpoints
        sr['service_config'] = {'config_rules': [config_rule]}

        try:
            update_response = sc.UpdateService(Service(**sr))
            flash(f'Created service {update_response.service_uuid.uuid}', 'success')
        except Exception as e:
            flash(f'Failure to update service {create_response.service_uuid.uuid} with endpoints and configuration, exception {str(e)}', 'danger')
            return redirect(request.url)

    return redirect(url_for('service.home'))

@service.get('<path:service_uuid>/detail')
def detail(service_uuid: str):
    """Render the detail page of one service, including its connections.

    When the service is not found, a warning is flashed and an empty Service (with no
    connections) is rendered. Any other error is logged and the user is redirected to
    the service list.

    :param service_uuid: UUID of the service, taken from the URL path.
    """
    if 'context_uuid' not in session or 'topology_uuid' not in session:
        flash("Please select a context!", "warning")
        return redirect(url_for("main.home"))
    context_uuid = session['context_uuid']

    try:
        # connected_client() closes the context client even when an RPC below raises.
        with connected_client(context_client):
            endpoint_ids = list()
            # Bound up front so the "not found" path can still render the template.
            connections = list()
            # NOTE(review): context_uuid from the session is not passed to
            # get_service_by_uuid; confirm the lookup targets the selected context.
            service_obj = get_service_by_uuid(context_client, service_uuid, rw_copy=False)
            if service_obj is None:
                flash('Context({:s})/Service({:s}) not found'.format(str(context_uuid), str(service_uuid)), 'danger')
                service_obj = Service()
            else:
                endpoint_ids.extend(service_obj.service_endpoint_ids)
                connections = context_client.ListConnections(service_obj.service_id).connections
                for connection in connections:
                    endpoint_ids.extend(connection.path_hops_endpoint_ids)

            if len(endpoint_ids) > 0:
                device_names, endpoints_data = get_endpoint_names(context_client, endpoint_ids)
            else:
                device_names, endpoints_data = dict(), dict()

        return render_template(
            'service/detail.html', service=service_obj, connections=connections, device_names=device_names,
            endpoints_data=endpoints_data, ste=ServiceTypeEnum, sse=ServiceStatusEnum, ile=IsolationLevelEnum)
    except Exception as e:
        flash('The system encountered an error and cannot show the details of this service.', 'warning')
        current_app.logger.exception(e)
        return redirect(url_for('service.home'))


@service.get('<path:service_uuid>/delete')
def delete(service_uuid: str):
    """Delete a service of the session's context and redirect to the service list.

    Success or failure is reported to the user through a flash message; failures
    are also logged.

    :param service_uuid: UUID of the service to delete, taken from the URL path.
    """
    if 'context_uuid' not in session or 'topology_uuid' not in session:
        flash("Please select a context!", "warning")
        return redirect(url_for("main.home"))
    context_uuid = session['context_uuid']

    try:
        # Named service_id (not "request") to avoid shadowing flask.request.
        service_id = ServiceId()
        service_id.service_uuid.uuid = service_uuid
        service_id.context_id.context_uuid.uuid = context_uuid
        # connected_client() closes the service client even if DeleteService raises.
        with connected_client(service_client) as sc:
            sc.DeleteService(service_id)

        flash('Service "{:s}" deleted successfully!'.format(service_uuid), 'success')
    except Exception as e:
        # gRPC errors expose details(); other exceptions (e.g. connection setup
        # failures) may not, and calling it would raise inside this handler.
        details_fn = getattr(e, 'details', None)
        details = details_fn() if callable(details_fn) else e
        flash('Problem deleting service "{:s}": {:s}'.format(service_uuid, str(details)), 'danger')
        current_app.logger.exception(e)
    return redirect(url_for('service.home'))
