# Copyright 2022-2024 ETSI OSG/SDG TeraFlowSDN (TFS) (https://tfs.etsi.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime, logging, uuid
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, selectinload, sessionmaker
from sqlalchemy_cockroachdb import run_transaction
from typing import Dict, List, Optional, Set, Tuple
from common.method_wrappers.ServiceExceptions import InvalidArgumentException, NotFoundException
from common.message_broker.MessageBroker import MessageBroker
from common.proto.context_pb2 import Empty
from common.proto.qkd_app_pb2 import (
    AppList, App, AppId)
from common.tools.grpc.Tools import grpc_message_to_json_string
from .models.QKDAppModel import AppModel
from .models.enums.QKDAppStatus import grpc_to_enum__qkd_app_status
from .models.enums.QKDAppTypes import grpc_to_enum__qkd_app_types
from .uuids.QKDApp import app_get_uuid
from common.tools.object_factory.Context import json_context_id
from common.tools.object_factory.QKDApp import json_app_id
from context.service.database.uuids.Context import context_get_uuid



#from .Events import notify_event_context, notify_event_device, notify_event_topology

LOGGER = logging.getLogger(__name__)


def app_list_objs(db_engine : Engine) -> AppList:
    """Return every stored QKD app wrapped in an ``AppList`` message.

    :param db_engine: SQLAlchemy engine used to open the transaction.
    :return: ``AppList`` whose ``apps`` field holds the dumped ``AppModel`` rows.
    """
    def dump_all(session : Session) -> List[Dict]:
        # Dump rows to plain dicts while the session is still open.
        all_rows : List[AppModel] = session.query(AppModel).all()
        return [app_row.dump() for app_row in all_rows]

    session_factory = sessionmaker(bind=db_engine)
    dumped_apps = run_transaction(session_factory, dump_all)
    return AppList(apps=dumped_apps)

def app_get(db_engine : Engine, request : AppId) -> App:
    """Fetch a single QKD app by its identifier.

    :param db_engine: SQLAlchemy engine used to open the transaction.
    :param request: ``AppId`` of the app; its UUID is resolved via
        ``app_get_uuid(..., allow_random=False)``.
    :return: the matching app as an ``App`` message.
    :raises NotFoundException: when no row has the resolved ``app_uuid``.
    """
    app_uuid = app_get_uuid(request, allow_random=False)

    def find_app(session : Session) -> Optional[Dict]:
        match : Optional[AppModel] = session.query(AppModel)\
            .filter_by(app_uuid=app_uuid).one_or_none()
        if match is None:
            return None
        return match.dump()

    app_dict = run_transaction(sessionmaker(bind=db_engine), find_app)
    if app_dict is not None:
        return App(**app_dict)

    # Report both the raw UUID from the request and the resolved one.
    raise NotFoundException('App', request.app_uuid.uuid, extra_details=[
        'app_uuid generated was: {:s}'.format(app_uuid)
    ])

def app_set(db_engine : Engine, messagebroker : MessageBroker, request : App) -> AppId:
    """Create or update (upsert) a QKD app row from an ``App`` message.

    The row is keyed on ``app_uuid``. On conflict, the columns listed in
    ``set_`` are overwritten. ``created_at`` and ``context_uuid`` keep the
    values from the original insert.

    :param db_engine: SQLAlchemy engine used to open the transaction.
    :param messagebroker: broker intended for change-event notification.
        Event publishing is currently disabled (see the commented-out calls
        at the end of this function), so it is not used yet.
    :param request: ``App`` message describing the app to store.
    :return: ``AppId`` (with its context id) of the stored app.
    """
    context_uuid = context_get_uuid(request.app_id.context_id, allow_random=False)
    app_uuid = app_get_uuid(request.app_id, allow_random=True)

    # Translate gRPC enum values into the ORM enum types.
    app_status = grpc_to_enum__qkd_app_status(request.app_status)
    app_type = grpc_to_enum__qkd_app_types(request.app_type)

    # Naive UTC timestamp: the same value datetime.utcnow() returns, but
    # without utcnow()'s deprecation warning on Python >= 3.12.
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    app_data = [{
        'context_uuid'       : context_uuid,
        'app_uuid'           : app_uuid,
        'app_status'         : app_status,
        'app_type'           : app_type,
        'server_app_id'      : request.server_app_id,
        'client_app_id'      : request.client_app_id,
        'backing_qkdl_uuid'  : [qkdl_id.qkdl_uuid.uuid for qkdl_id in request.backing_qkdl_id],
        'local_device_uuid'  : request.local_device_id.device_uuid.uuid,
        # An empty remote device UUID is stored as NULL.
        'remote_device_uuid' : request.remote_device_id.device_uuid.uuid or None,
        'created_at'         : now,
        'updated_at'         : now,
    }]

    def callback(session : Session) -> bool:
        """Run the upsert; return True if an existing row was updated."""
        stmt = insert(AppModel).values(app_data)
        stmt = stmt.on_conflict_do_update(
            index_elements=[AppModel.app_uuid],
            set_=dict(
                app_status         = stmt.excluded.app_status,
                app_type           = stmt.excluded.app_type,
                server_app_id      = stmt.excluded.server_app_id,
                client_app_id      = stmt.excluded.client_app_id,
                backing_qkdl_uuid  = stmt.excluded.backing_qkdl_uuid,
                local_device_uuid  = stmt.excluded.local_device_uuid,
                remote_device_uuid = stmt.excluded.remote_device_uuid,
                updated_at         = stmt.excluded.updated_at,
            )
        )
        stmt = stmt.returning(AppModel.created_at, AppModel.updated_at)
        created_at, updated_at = session.execute(stmt).fetchone()
        # A fresh insert writes the same `now` into both columns. An update
        # keeps the stored created_at and writes the current `now` into
        # updated_at.
        return updated_at > created_at

    # `updated` is kept for the (currently disabled) event notification below.
    updated = run_transaction(sessionmaker(bind=db_engine), callback)
    context_id = json_context_id(context_uuid)
    app_id = json_app_id(app_uuid, context_id=context_id)
    #event_type = EventTypeEnum.EVENTTYPE_UPDATE if updated else EventTypeEnum.EVENTTYPE_CREATE
    #notify_event_app(messagebroker, event_type, app_id)
    #notify_event_context(messagebroker, EventTypeEnum.EVENTTYPE_UPDATE, context_id)
    return AppId(**app_id)



def app_get_by_server(db_engine : Engine, request : str) -> App:
    """Look up the QKD app registered under a given ``server_app_id``.

    :param db_engine: SQLAlchemy engine used to open the transaction.
    :param request: the ``server_app_id`` value to match.
    :return: the matching app as an ``App`` message.
    :raises NotFoundException: if no app has that ``server_app_id``.

    Note: ``one_or_none()`` raises SQLAlchemy's ``MultipleResultsFound`` if
    several rows share the same ``server_app_id``.
    """
    def find_by_server(session : Session) -> Optional[Dict]:
        query = session.query(AppModel).filter_by(server_app_id=request)
        match : Optional[AppModel] = query.one_or_none()
        return match.dump() if match is not None else None

    app_dict = run_transaction(sessionmaker(bind=db_engine), find_by_server)
    if app_dict is None:
        raise NotFoundException('No app match found for', request)
    return App(**app_dict)



# NOTE(review): the module-level string literal below is never executed. It
# holds device_delete/device_select implementations, apparently kept as
# reference. They use names (DeviceId, DeviceFilter, DeviceList, DeviceModel,
# TopologyDeviceModel, device_get_uuid, json_device_id, notify_event_*,
# EventTypeEnum) that are not imported in this module. They would need those
# imports, and QKD-app equivalents, before being enabled. Consider moving them
# out or deleting them.
"""
def device_delete(db_engine : Engine, messagebroker : MessageBroker, request : DeviceId) -> Empty:
    device_uuid = device_get_uuid(request, allow_random=False)
    def callback(session : Session) -> Tuple[bool, List[Dict]]:
        query = session.query(TopologyDeviceModel)
        query = query.filter_by(device_uuid=device_uuid)
        topology_device_list : List[TopologyDeviceModel] = query.all()
        topology_ids = [obj.topology.dump_id() for obj in topology_device_list]
        num_deleted = session.query(DeviceModel).filter_by(device_uuid=device_uuid).delete()
        return num_deleted > 0, topology_ids
    deleted, updated_topology_ids = run_transaction(sessionmaker(bind=db_engine), callback)
    device_id = json_device_id(device_uuid)
    if deleted:
        notify_event_device(messagebroker, EventTypeEnum.EVENTTYPE_REMOVE, device_id)

        context_ids  : Dict[str, Dict] = dict()
        topology_ids : Dict[str, Dict] = dict()
        for topology_id in updated_topology_ids:
            topology_uuid = topology_id['topology_uuid']['uuid']
            topology_ids[topology_uuid] = topology_id
            context_id = topology_id['context_id']
            context_uuid = context_id['context_uuid']['uuid']
            context_ids[context_uuid] = context_id

        for topology_id in topology_ids.values():
            notify_event_topology(messagebroker, EventTypeEnum.EVENTTYPE_UPDATE, topology_id)

        for context_id in context_ids.values():
            notify_event_context(messagebroker, EventTypeEnum.EVENTTYPE_UPDATE, context_id)

    return Empty()

def device_select(db_engine : Engine, request : DeviceFilter) -> DeviceList:
    device_uuids = [
        device_get_uuid(device_id, allow_random=False)
        for device_id in request.device_ids.device_ids
    ]
    dump_params = dict(
        include_endpoints   =request.include_endpoints,
        include_config_rules=request.include_config_rules,
        include_components  =request.include_components,
    )
    def callback(session : Session) -> List[Dict]:
        query = session.query(DeviceModel)
        if request.include_endpoints   : query = query.options(selectinload(DeviceModel.endpoints))
        if request.include_config_rules: query = query.options(selectinload(DeviceModel.config_rules))
        #if request.include_components  : query = query.options(selectinload(DeviceModel.components))
        obj_list : List[DeviceModel] = query.filter(DeviceModel.device_uuid.in_(device_uuids)).all()
        return [obj.dump(**dump_params) for obj in obj_list]
    devices = run_transaction(sessionmaker(bind=db_engine), callback)
    return DeviceList(devices=devices)
"""
