# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import sys
from flask import jsonify, redirect, render_template, Blueprint, flash, session, url_for, request
from webui.Config import (CONTEXT_SERVICE_ADDRESS, CONTEXT_SERVICE_PORT,
                DEVICE_SERVICE_ADDRESS, DEVICE_SERVICE_PORT)
from context.client.ContextClient import ContextClient
from device.client.DeviceClient import DeviceClient
from webui.proto.context_pb2 import Context, Device, Empty, Link, Topology
from webui.service.main.forms import ContextForm, DescriptorForm

# Blueprint holding the top-level WebUI routes defined in this module
# ('/', '/topology', '/about', '/debug', '/resetsession').
main = Blueprint('main', __name__)

# Module-level clients shared by every request served by this blueprint. The
# views that use them (home, topology) call connect() before use and close()
# afterwards.
# NOTE(review): if the server handles requests concurrently, one request's
# close() may tear down a channel another request is still using — confirm the
# deployment model, or consider per-request clients.
context_client: ContextClient = ContextClient(CONTEXT_SERVICE_ADDRESS, CONTEXT_SERVICE_PORT)
device_client: DeviceClient = DeviceClient(DEVICE_SERVICE_ADDRESS, DEVICE_SERVICE_PORT)

logger: logging.Logger = logging.getLogger(__name__)

def process_descriptor(item_name_singluar, item_name_plural, grpc_method, grpc_class, items):
    """Create every entity of one descriptor section and flash a summary.

    Each element of ``items`` is a dict of keyword arguments used to build a
    ``grpc_class`` message, which is then passed to ``grpc_method``. A failure
    on one item is flashed and counted; it does not stop the remaining items.

    :param item_name_singluar: singular entity name used in per-item error
        messages (the misspelt parameter name is kept so keyword callers keep
        working).
    :param item_name_plural: plural entity name used in the summary messages.
    :param grpc_method: callable invoked with each constructed message.
    :param grpc_class: message class instantiated as ``grpc_class(**item)``.
    :param items: iterable of keyword-argument dicts; ``None`` is treated as an
        empty section.
    """
    num_ok, num_err = 0, 0
    for item in items or ():
        try:
            grpc_method(grpc_class(**item))
            num_ok += 1
        except Exception as e: # pylint: disable=broad-except
            # 'danger' matches the category used by every other error flash in
            # this module ('error' was the odd one out).
            flash(f'Unable to add {item_name_singluar} {item}: {e}', 'danger')
            num_err += 1
    if num_ok:
        flash(f'{num_ok} {item_name_plural} added', 'success')
    if num_err:
        flash(f'{num_err} {item_name_plural} failed', 'danger')

def process_descriptors(descriptors):
    """Load the uploaded JSON descriptor file and create the entities it lists.

    :param descriptors: the upload form field; its ``name`` is used to look up
        the uploaded file in ``request.files``.

    The file must contain a JSON object. Its ``contexts``, ``topologies``,
    ``devices`` and ``links`` lists are processed in that order through
    process_descriptor(). A missing or null section is treated as empty.
    If the file cannot be read or parsed, or is not a JSON object, an error
    is flashed and nothing is created.
    """
    try:
        descriptors_file = request.files[descriptors.name]
        descriptors_data = descriptors_file.read()
        descriptor_dict = json.loads(descriptors_data)
    except Exception as e: # pylint: disable=broad-except
        flash(f'Unable to load descriptor file: {str(e)}', 'danger')
        return

    # Indexing a list/str/number top-level value by section name would crash
    # below, so reject it up front with the same user-facing prefix.
    if not isinstance(descriptor_dict, dict):
        flash('Unable to load descriptor file: expected a JSON object at the top level', 'danger')
        return

    # Logged at DEBUG: descriptor files can be large and are not a warning.
    logger.debug('Loaded descriptors from field %s: %s', descriptors.name, descriptor_dict)

    # NOTE(review): the order presumably matters (entities referenced by later
    # sections are created first) — keep contexts, topologies, devices, links.
    process_descriptor('Context',  'Contexts',   context_client.SetContext,  Context,
                       descriptor_dict.get('contexts') or [])
    process_descriptor('Topology', 'Topologies', context_client.SetTopology, Topology,
                       descriptor_dict.get('topologies') or [])
    process_descriptor('Device',   'Devices',    device_client.AddDevice,    Device,
                       descriptor_dict.get('devices') or [])
    process_descriptor('Link',     'Links',      context_client.SetLink,     Link,
                       descriptor_dict.get('links') or [])

@main.route('/', methods=['GET', 'POST'])
def home():
    """Render the home page and handle its context and descriptor forms.

    The context selector is populated from ``ListContextIds``. On POST:
      * a valid context form stores the chosen context UUID in the session and
        redirects back to the home page;
      * a valid descriptor form is handed to process_descriptors() and the
        user is redirected back to the home page; exceptions raised while doing
        so are logged and flashed instead of propagated.

    Both module-level clients are connected on entry. They are closed on every
    exit path, including the redirects and a failing ``ListContextIds`` call.
    """
    context_client.connect()
    device_client.connect()
    try:
        response = context_client.ListContextIds(Empty())
        context_form: ContextForm = ContextForm()
        context_form.context.choices.append(('', 'Select...'))
        for context_id in response.context_ids:
            context_uuid = context_id.context_uuid.uuid
            # Label with the plain UUID string; str() of the protobuf message
            # would render its text-format representation in the <select>.
            context_form.context.choices.append((context_uuid, context_uuid))
        if context_form.validate_on_submit():
            session['context_uuid'] = context_form.context.data
            flash(f'The context was successfully set to `{context_form.context.data}`.', 'success')
            return redirect(url_for("main.home"))
        if 'context_uuid' in session:
            context_form.context.data = session['context_uuid']

        descriptor_form: DescriptorForm = DescriptorForm()
        try:
            if descriptor_form.validate_on_submit():
                process_descriptors(descriptor_form.descriptors)
                return redirect(url_for("main.home"))
        except Exception as e: # pylint: disable=broad-except
            logger.exception('Descriptor load failed')
            flash(f'Descriptor load failed: `{str(e)}`', 'danger')
    finally:
        # Runs on every exit path (both redirects, exceptions, normal flow) so
        # the clients connected above are never left open.
        context_client.close()
        device_client.close()
    return render_template('main/home.html', context_form=context_form, descriptor_form=descriptor_form)

@main.route('/topology', methods=['GET'])
def topology():
    """Return the devices and links known to the Context service as JSON.

    Response body::

        {'devices': [{'id', 'name', 'type'}, ...],
         'links':   [{'id', 'source', 'target'}, ...]}

    Link 'source'/'target' are the device UUIDs of the link's first two
    endpoints. A link with fewer than two endpoints is logged and skipped.
    On any other failure the error is logged and an empty topology is returned
    with HTTP 500, rather than the view returning ``None``.
    """
    context_client.connect()
    try:
        device_response = context_client.ListDevices(Empty())
        devices = [{
            'id': device.device_id.device_uuid.uuid,
            'name': device.device_id.device_uuid.uuid,
            'type': device.device_type,
        } for device in device_response.devices]

        link_response = context_client.ListLinks(Empty())
        links = []
        for link in link_response.links:
            endpoint_ids = link.link_endpoint_ids
            if len(endpoint_ids) < 2:
                # Indexing [0]/[1] would raise IndexError and drop the whole
                # topology; skip just this link instead.
                logger.warning('Skipping link %s: it has %d endpoint(s), expected at least 2',
                               link.link_id.link_uuid.uuid, len(endpoint_ids))
                continue
            links.append({
                'id': link.link_id.link_uuid.uuid,
                'source': endpoint_ids[0].device_id.device_uuid.uuid,
                'target': endpoint_ids[1].device_id.device_uuid.uuid,
            })

        return jsonify({'devices': devices, 'links': links})
    except Exception: # pylint: disable=broad-except
        logger.exception('Error retrieving topology')
        return jsonify({'devices': [], 'links': [], 'error': 'Error retrieving topology'}), 500
    finally:
        context_client.close()

@main.get('/about')
def about():
    """Serve the About page (template rendered without extra variables)."""
    page_template = 'main/about.html'
    return render_template(page_template)

@main.get('/debug')
def debug():
    """Serve the Debug page (template rendered without extra variables)."""
    page_template = 'main/debug.html'
    return render_template(page_template)

@main.get('/resetsession')
def reset_session():
    """Forget everything stored in the user's session, then go to the home page."""
    session.clear()
    home_url = url_for("main.home")
    return redirect(home_url)
