# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json, logging, re
from flask import jsonify, redirect, render_template, Blueprint, flash, session, url_for, request
from common.proto.context_pb2 import (
    Connection, Context, Device, Empty, Link, Service, Slice, Topology, ContextIdList, TopologyId, TopologyIdList)
from common.tools.grpc.Tools import grpc_message_to_json_string
from common.tools.object_factory.Context import json_context_id
from common.tools.object_factory.Topology import json_topology_id
from context.client.ContextClient import ContextClient
from device.client.DeviceClient import DeviceClient
from service.client.ServiceClient import ServiceClient
from slice.client.SliceClient import SliceClient
from webui.service.main.DescriptorTools import (
    format_custom_config_rules, get_descriptors_add_contexts, get_descriptors_add_services, get_descriptors_add_slices,
    get_descriptors_add_topologies, split_devices_by_rules)
from webui.service.main.forms import ContextTopologyForm, DescriptorForm

# Flask blueprint that the routes defined in this module are registered on.
main = Blueprint('main', __name__)

# Module-level gRPC client instances shared by all handlers in this module; the handlers call connect() on them
# before issuing requests and close() afterwards.
# NOTE(review): if the WSGI server handles requests concurrently, one request's close() could race with another
# request still using the same shared client — confirm the deployment model.
context_client = ContextClient()
device_client = DeviceClient()
service_client = ServiceClient()
slice_client = SliceClient()

logger = logging.getLogger(__name__)

# Display labels used when flashing descriptor-processing results.
# Maps entity name -> (singular label, plural label).
ENTITY_TO_TEXT = dict(
    context=('Context', 'Contexts'),
    topology=('Topology', 'Topologies'),
    device=('Device', 'Devices'),
    link=('Link', 'Links'),
    service=('Service', 'Services'),
    slice=('Slice', 'Slices'),
    connection=('Connection', 'Connections'),
)

# Verb forms used when flashing descriptor-processing results.
# Maps action name -> (infinitive, past tense).
ACTION_TO_TEXT = dict(
    add=('Add', 'Added'),
    update=('Update', 'Updated'),
    config=('Configure', 'Configured'),
)

def process_descriptor(entity_name, action_name, grpc_method, grpc_class, entities):
    """Build a gRPC message for each entity descriptor, send it, and flash the outcome.

    Args:
        entity_name: key into ENTITY_TO_TEXT (e.g. 'device'); selects the labels used in flashed messages.
        action_name: key into ACTION_TO_TEXT (e.g. 'add'); selects the verbs used in flashed messages.
        grpc_method: callable invoked once per entity with the constructed gRPC message.
        grpc_class: gRPC message class; each entity dict is expanded into it as keyword arguments.
        entities: iterable of dicts, one per entity descriptor.

    Failures are isolated per entity. Each failing entity is flashed as an error and counted, and processing
    continues with the next one. At the end, a success summary and a failure summary are flashed when their
    respective counts are non-zero.
    """
    entity_name_singular, entity_name_plural = ENTITY_TO_TEXT[entity_name]
    action_infinitive, action_past = ACTION_TO_TEXT[action_name]
    num_ok, num_err = 0, 0
    for entity in entities:
        try:
            grpc_method(grpc_class(**entity))
            num_ok += 1
        except Exception as e: # pylint: disable=broad-except
            # Same 'danger' category as the other error flashes in this module, so the alert is styled consistently.
            flash(f'Unable to {action_infinitive} {entity_name_singular} {str(entity)}: {str(e)}', 'danger')
            num_err += 1
    if num_ok : flash(f'{str(num_ok)} {entity_name_plural} {action_past}', 'success')
    if num_err: flash(f'{str(num_err)} {entity_name_plural} failed', 'danger')

def _format_entities_config_rules(entities, config_field):
    """Normalize, in place, the custom config rules stored at entity[config_field]['config_rules'].

    A missing (or null) config section is created as an empty dict. Entities that declare no config section are
    therefore accepted, instead of failing with KeyError when the formatted rules are written back.
    """
    for entity in entities:
        entity_config = entity.get(config_field)
        if entity_config is None:
            entity_config = entity[config_field] = {}
        config_rules = entity_config.get('config_rules', [])
        entity_config['config_rules'] = format_custom_config_rules(config_rules)

def process_descriptors(descriptors):
    """Load the JSON descriptor file uploaded through the `descriptors` form field and provision its entities.

    If the file cannot be read or parsed, the error is flashed and nothing is provisioned.

    In dummy mode (`"dummy_mode": true`), every entity is written straight into the Context component.

    In normal mode, entities go through the Device, Service and Slice components so their automated workflows run.
    Each entity is first created and then configured/updated.

    Raises:
        ValueError: in normal mode, when the descriptor file defines connections.
    """
    try:
        descriptors_file = request.files[descriptors.name]
        descriptors_data = descriptors_file.read()
        descriptors_json = json.loads(descriptors_data)
    except Exception as e: # pylint: disable=broad-except
        flash(f'Unable to load descriptor file: {str(e)}', 'danger')
        return

    dummy_mode  = descriptors_json.get('dummy_mode' , False)
    contexts    = descriptors_json.get('contexts'   , [])
    topologies  = descriptors_json.get('topologies' , [])
    devices     = descriptors_json.get('devices'    , [])
    links       = descriptors_json.get('links'      , [])
    services    = descriptors_json.get('services'   , [])
    slices      = descriptors_json.get('slices'     , [])
    connections = descriptors_json.get('connections', [])

    # Format CustomConfigRules in Devices, Services and Slices provided in JSON format
    _format_entities_config_rules(devices,  'device_config' )
    _format_entities_config_rules(services, 'service_config')
    _format_entities_config_rules(slices,   'slice_config'  )

    # Context and Topology require to create the entity first, and add devices, links, services, slices, etc. in a
    # second stage.
    contexts_add = get_descriptors_add_contexts(contexts)
    topologies_add = get_descriptors_add_topologies(topologies)

    if dummy_mode:
        # Dummy Mode: used to pre-load databases (WebUI debugging purposes) with no smart or automated tasks.
        context_client.connect()
        try:
            process_descriptor('context',    'add',    context_client.SetContext,    Context,    contexts_add  )
            process_descriptor('topology',   'add',    context_client.SetTopology,   Topology,   topologies_add)
            process_descriptor('device',     'add',    context_client.SetDevice,     Device,     devices       )
            process_descriptor('link',       'add',    context_client.SetLink,       Link,       links         )
            process_descriptor('service',    'add',    context_client.SetService,    Service,    services      )
            process_descriptor('slice',      'add',    context_client.SetSlice,      Slice,      slices        )
            process_descriptor('connection', 'add',    context_client.SetConnection, Connection, connections   )
            process_descriptor('context',    'update', context_client.SetContext,    Context,    contexts      )
            process_descriptor('topology',   'update', context_client.SetTopology,   Topology,   topologies    )
        finally:
            context_client.close()
    else:
        # Normal mode: follows the automated workflows in the different components.
        # Explicit check (not `assert`) so the validation is not stripped when Python runs with -O.
        if connections:
            raise ValueError('in normal mode, connections should not be set')

        # Device, Service and Slice require to first create the entity and then configure it
        devices_add, devices_config = split_devices_by_rules(devices)
        services_add = get_descriptors_add_services(services)
        slices_add = get_descriptors_add_slices(slices)

        context_client.connect()
        device_client.connect()
        service_client.connect()
        slice_client.connect()
        try:
            process_descriptor('context',    'add',    context_client.SetContext,      Context,    contexts_add  )
            process_descriptor('topology',   'add',    context_client.SetTopology,     Topology,   topologies_add)
            process_descriptor('device',     'add',    device_client .AddDevice,       Device,     devices_add   )
            process_descriptor('device',     'config', device_client .ConfigureDevice, Device,     devices_config)
            process_descriptor('link',       'add',    context_client.SetLink,         Link,       links         )
            process_descriptor('service',    'add',    service_client.CreateService,   Service,    services_add  )
            process_descriptor('service',    'update', service_client.UpdateService,   Service,    services      )
            process_descriptor('slice',      'add',    slice_client  .CreateSlice,     Slice,      slices_add    )
            process_descriptor('slice',      'update', slice_client  .UpdateSlice,     Slice,      slices        )
            process_descriptor('context',    'update', context_client.SetContext,      Context,    contexts      )
            process_descriptor('topology',   'update', context_client.SetTopology,     Topology,   topologies    )
        finally:
            # Always release the channels, even if processing raised.
            slice_client.close()
            service_client.close()
            device_client.close()
            context_client.close()

@main.route('/', methods=['GET', 'POST'])
def home():
    """Render the home page: the context/topology selector and the descriptor-file upload form.

    On a valid context/topology selection, the selection is stored in the session and the client is redirected
    back here.

    On a valid descriptor upload, the file is processed by process_descriptors() and the client is redirected back
    here. Processing failures are logged and flashed instead of propagated.

    The Context and Device clients are connected for the duration of the request and are always closed, including
    on redirects and on errors while listing contexts/topologies.
    """
    context_client.connect()
    device_client.connect()
    try:
        context_topology_form: ContextTopologyForm = ContextTopologyForm()
        # NOTE(review): this appends to the field's choices list; confirm the list is per-instance (not shared via
        # the form class) so entries do not accumulate across requests.
        context_topology_form.context_topology.choices.append(('', 'Select...'))

        ctx_response: ContextIdList = context_client.ListContextIds(Empty())
        for context_id in ctx_response.context_ids:
            context_uuid = context_id.context_uuid.uuid
            topo_response: TopologyIdList = context_client.ListTopologyIds(context_id)
            for topology_id in topo_response.topology_ids:
                topology_uuid = topology_id.topology_uuid.uuid
                context_topology_uuid  = 'ctx[{:s}]/topo[{:s}]'.format(context_uuid, topology_uuid)
                context_topology_name  = 'Context({:s}):Topology({:s})'.format(context_uuid, topology_uuid)
                context_topology_entry = (context_topology_uuid, context_topology_name)
                context_topology_form.context_topology.choices.append(context_topology_entry)

        if context_topology_form.validate_on_submit():
            context_topology_uuid = context_topology_form.context_topology.data
            if context_topology_uuid:
                # Raw string: avoids invalid escape-sequence warnings for `\[` / `\]`.
                match = re.match(r'ctx\[([^\]]+)\]/topo\[([^\]]+)\]', context_topology_uuid)
                if match is not None:
                    session['context_topology_uuid'] = context_topology_uuid = match.group(0)
                    session['context_uuid'] = context_uuid = match.group(1)
                    session['topology_uuid'] = topology_uuid = match.group(2)
                    msg = f'Context({context_uuid})/Topology({topology_uuid}) successfully selected.'
                    flash(msg, 'success')
                    return redirect(url_for("main.home"))

        if 'context_topology_uuid' in session:
            context_topology_form.context_topology.data = session['context_topology_uuid']

        descriptor_form: DescriptorForm = DescriptorForm()
        try:
            if descriptor_form.validate_on_submit():
                process_descriptors(descriptor_form.descriptors)
                return redirect(url_for("main.home"))
        except Exception as e: # pylint: disable=broad-except
            logger.exception('Descriptor load failed')
            flash(f'Descriptor load failed: `{str(e)}`', 'danger')
    finally:
        context_client.close()
        device_client.close()

    return render_template(
        'main/home.html', context_topology_form=context_topology_form, descriptor_form=descriptor_form)

@main.route('/topology', methods=['GET'])
def topology():
    """Return the devices and links of the session-selected topology as JSON.

    The response body is {'devices': [...], 'links': [...]}. Both lists are empty when no context/topology has been
    selected in the session.

    Links that do not have exactly two endpoints are skipped with a warning.

    If an error occurs, it is logged and an empty topology is returned with HTTP status 500. Previously the view
    fell through and returned None, which Flask rejects.
    """
    context_client.connect()
    try:
        if 'context_topology_uuid' not in session:
            return jsonify({'devices': [], 'links': []})

        context_uuid = session['context_uuid']
        topology_uuid = session['topology_uuid']

        json_topo_id = json_topology_id(topology_uuid, context_id=json_context_id(context_uuid))
        grpc_topology = context_client.GetTopology(TopologyId(**json_topo_id))

        topo_device_uuids = {device_id.device_uuid.uuid for device_id in grpc_topology.device_ids}
        topo_link_uuids   = {link_id  .link_uuid  .uuid for link_id   in grpc_topology.link_ids  }

        devices = [
            {
                'id': device.device_id.device_uuid.uuid,
                'name': device.device_id.device_uuid.uuid,
                'type': device.device_type,
            }
            for device in context_client.ListDevices(Empty()).devices
            if device.device_id.device_uuid.uuid in topo_device_uuids
        ]

        links = []
        for link in context_client.ListLinks(Empty()).links:
            if link.link_id.link_uuid.uuid not in topo_link_uuids: continue
            if len(link.link_endpoint_ids) != 2:
                str_link = grpc_message_to_json_string(link)
                logger.warning('Unexpected link with len(endpoints) != 2: %s', str_link)
                continue
            links.append({
                'id': link.link_id.link_uuid.uuid,
                'source': link.link_endpoint_ids[0].device_id.device_uuid.uuid,
                'target': link.link_endpoint_ids[1].device_id.device_uuid.uuid,
            })

        return jsonify({'devices': devices, 'links': links})
    except Exception: # pylint: disable=broad-except
        # Catch Exception (not BaseException) so KeyboardInterrupt/SystemExit still propagate.
        logger.exception('Error retrieving topology')
        return jsonify({'devices': [], 'links': []}), 500
    finally:
        context_client.close()

@main.get('/about')
def about():
    """Serve the static 'About' page."""
    page_template = 'main/about.html'
    return render_template(page_template)

@main.get('/debug')
def debug():
    """Serve the WebUI debug page."""
    page_template = 'main/debug.html'
    return render_template(page_template)

@main.get('/resetsession')
def reset_session():
    """Forget every value stored in the user session, then send the user back to the home page."""
    session.clear()
    home_url = url_for("main.home")
    return redirect(home_url)
