# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy_cockroachdb import run_transaction
from typing import Dict, List, Optional
from common.proto.context_pb2 import ContextId, Service, ServiceId, ServiceIdList, ServiceList
from common.rpc_method_wrapper.ServiceExceptions import InvalidArgumentException, NotFoundException
from context.service.database.models.ServiceModel import ServiceModel

def service_list_ids(db_engine : Engine, request : ContextId) -> ServiceIdList:
    """Return the identifiers of every service stored under a given context.

    :param db_engine: SQLAlchemy engine the transaction is opened on.
    :param request: ContextId whose ``context_uuid.uuid`` selects the services.
    :return: ServiceIdList built from ``ServiceModel.dump_id()`` of each matching row.
    """
    wanted_context = request.context_uuid.uuid

    def _collect_ids(session : Session) -> List[Dict]:
        rows = session.query(ServiceModel).filter_by(context_uuid=wanted_context).all()
        return [row.dump_id() for row in rows]

    session_factory = sessionmaker(bind=db_engine)
    id_dicts = run_transaction(session_factory, _collect_ids)
    return ServiceIdList(service_ids=id_dicts)

def service_list_objs(db_engine : Engine, request : ContextId) -> ServiceList:
    """Return the full description of every service stored under a given context.

    :param db_engine: SQLAlchemy engine the transaction is opened on.
    :param request: ContextId whose ``context_uuid.uuid`` selects the services.
    :return: ServiceList built from ``ServiceModel.dump()`` of each matching row.
    """
    wanted_context = request.context_uuid.uuid

    def _collect_services(session : Session) -> List[Dict]:
        rows = session.query(ServiceModel).filter_by(context_uuid=wanted_context).all()
        return [row.dump() for row in rows]

    session_factory = sessionmaker(bind=db_engine)
    service_dicts = run_transaction(session_factory, _collect_services)
    return ServiceList(services=service_dicts)

def service_get(db_engine : Engine, request : ServiceId) -> Service:
    """Fetch one service identified by its (context_uuid, service_uuid) pair.

    :param db_engine: SQLAlchemy engine the transaction is opened on.
    :param request: ServiceId carrying the context and service UUIDs.
    :return: Service message built from ``ServiceModel.dump()``.
    :raises NotFoundException: if no row matches; the reported key is
        "<context_uuid>/<service_uuid>".
    """
    context_uuid = request.context_id.context_uuid.uuid
    service_uuid = request.service_uuid.uuid

    def _load(session : Session) -> Optional[Dict]:
        row : Optional[ServiceModel] = session.query(ServiceModel).filter_by(
            context_uuid=context_uuid, service_uuid=service_uuid).one_or_none()
        if row is None:
            return None
        return row.dump()

    service_dict = run_transaction(sessionmaker(bind=db_engine), _load)
    if service_dict is not None:
        return Service(**service_dict)
    raise NotFoundException('Service', '{:s}/{:s}'.format(context_uuid, service_uuid))

def service_set(db_engine : Engine, request : Service) -> bool:
    """Insert the service described by ``request``, or update it if it already exists.

    Every endpoint whose topology id carries a non-empty context UUID must
    reference the same context as the service itself.

    The row is upserted on (context_uuid, service_uuid). On conflict only
    ``service_name`` is overwritten; ``created_at`` is not part of the update set.

    :param db_engine: SQLAlchemy engine the transaction is opened on.
    :param request: Service message to persist.
    :return: always False for now (created/updated detection is still a TODO).
    :raises InvalidArgumentException: if an endpoint references a different context.
    """
    service_id = request.service_id
    context_uuid = service_id.context_id.context_uuid.uuid
    service_uuid = service_id.service_uuid.uuid
    service_name = request.name

    # Endpoints may omit their context; when present it must match the service's.
    expected_field = 'request.service_id.context_id.context_uuid.uuid'
    for position, endpoint_id in enumerate(request.service_endpoint_ids):
        endpoint_context_uuid = endpoint_id.topology_id.context_id.context_uuid.uuid
        if not endpoint_context_uuid or endpoint_context_uuid == context_uuid:
            continue
        raise InvalidArgumentException(
            'request.service_endpoint_ids[{:d}].topology_id.context_id.context_uuid.uuid'.format(position),
            endpoint_context_uuid,
            ['should be == {:s}({:s})'.format(expected_field, context_uuid)])

    def _upsert(session : Session) -> None:
        # The timestamp is evaluated on every invocation of this callback.
        row = {
            'context_uuid': context_uuid,
            'service_uuid': service_uuid,
            'service_name': service_name,
            'created_at'  : time.time(),
        }
        insert_stmt = insert(ServiceModel).values([row])
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements=[ServiceModel.context_uuid, ServiceModel.service_uuid],
            set_={'service_name': insert_stmt.excluded.service_name},
        )
        session.execute(upsert_stmt)

    run_transaction(sessionmaker(bind=db_engine), _upsert)
    # TODO: improve and check if created/updated
    return False


#        # db_context : ContextModel = get_object(self.database, ContextModel, context_uuid)
#        db_context = session.query(ContextModel).filter_by(context_uuid=context_uuid).one_or_none()
#        # str_service_key = key_to_str([context_uuid, service_uuid])
#        constraints_result = self.set_constraints(service_uuid, 'constraints', request.service_constraints)
#        db_constraints = constraints_result[0][0]
#
#        config_rules = grpc_config_rules_to_raw(request.service_config.config_rules)
#        running_config_result = update_config(self.database, str_service_key, 'running', config_rules)
#        db_running_config = running_config_result[0][0]
#
#        result : Tuple[ServiceModel, bool] = update_or_create_object(self.database, ServiceModel, str_service_key, {
#            'context_fk'            : db_context,
#            'service_uuid'          : service_uuid,
#            'service_type'          : grpc_to_enum__service_type(request.service_type),
#            'service_constraints_fk': db_constraints,
#            'service_status'        : grpc_to_enum__service_status(request.service_status.service_status),
#            'service_config_fk'     : db_running_config,
#        })
#        db_service, updated = result
#
#        for i,endpoint_id in enumerate(request.service_endpoint_ids):
#            endpoint_uuid                  = endpoint_id.endpoint_uuid.uuid
#            endpoint_device_uuid           = endpoint_id.device_id.device_uuid.uuid
#            endpoint_topology_uuid         = endpoint_id.topology_id.topology_uuid.uuid
#            endpoint_topology_context_uuid = endpoint_id.topology_id.context_id.context_uuid.uuid
#
#            str_endpoint_key = key_to_str([endpoint_device_uuid, endpoint_uuid])
#            if len(endpoint_topology_context_uuid) > 0 and len(endpoint_topology_uuid) > 0:
#                str_topology_key = key_to_str([endpoint_topology_context_uuid, endpoint_topology_uuid])
#                str_endpoint_key = key_to_str([str_endpoint_key, str_topology_key], separator=':')
#
#            db_endpoint : EndPointModel = get_object(self.database, EndPointModel, str_endpoint_key)
#
#            str_service_endpoint_key = key_to_str([service_uuid, str_endpoint_key], separator='--')
#            result : Tuple[ServiceEndPointModel, bool] = get_or_create_object(
#                self.database, ServiceEndPointModel, str_service_endpoint_key, {
#                    'service_fk': db_service, 'endpoint_fk': db_endpoint})
#            #db_service_endpoint, service_endpoint_created = result
#
#        event_type = EventTypeEnum.EVENTTYPE_UPDATE if updated else EventTypeEnum.EVENTTYPE_CREATE
#        dict_service_id = db_service.dump_id()
#        notify_event(self.messagebroker, TOPIC_SERVICE, event_type, {'service_id': dict_service_id})
#        return ServiceId(**dict_service_id)
#    context_uuid = request.service_id.context_id.context_uuid.uuid
#    db_context : ContextModel = get_object(self.database, ContextModel, context_uuid)
#
#    for i,endpoint_id in enumerate(request.service_endpoint_ids):
#        endpoint_topology_context_uuid = endpoint_id.topology_id.context_id.context_uuid.uuid
#        if len(endpoint_topology_context_uuid) > 0 and context_uuid != endpoint_topology_context_uuid:
#            raise InvalidArgumentException(
#                'request.service_endpoint_ids[{:d}].topology_id.context_id.context_uuid.uuid'.format(i),
#                endpoint_topology_context_uuid,
#                ['should be == {:s}({:s})'.format(
#                    'request.service_id.context_id.context_uuid.uuid', context_uuid)])
#
#    service_uuid = request.service_id.service_uuid.uuid
#    str_service_key = key_to_str([context_uuid, service_uuid])
#
#    constraints_result = set_constraints(
#        self.database, str_service_key, 'service', request.service_constraints)
#    db_constraints = constraints_result[0][0]
#
#    running_config_rules = update_config(
#        self.database, str_service_key, 'service', request.service_config.config_rules)
#    db_running_config = running_config_rules[0][0]
#
#    result : Tuple[ServiceModel, bool] = update_or_create_object(self.database, ServiceModel, str_service_key, {
#        'context_fk'            : db_context,
#        'service_uuid'          : service_uuid,
#        'service_type'          : grpc_to_enum__service_type(request.service_type),
#        'service_constraints_fk': db_constraints,
#        'service_status'        : grpc_to_enum__service_status(request.service_status.service_status),
#        'service_config_fk'     : db_running_config,
#    })
#    db_service, updated = result
#
#    for i,endpoint_id in enumerate(request.service_endpoint_ids):
#        endpoint_uuid                  = endpoint_id.endpoint_uuid.uuid
#        endpoint_device_uuid           = endpoint_id.device_id.device_uuid.uuid
#        endpoint_topology_uuid         = endpoint_id.topology_id.topology_uuid.uuid
#        endpoint_topology_context_uuid = endpoint_id.topology_id.context_id.context_uuid.uuid
#
#        str_endpoint_key = key_to_str([endpoint_device_uuid, endpoint_uuid])
#        if len(endpoint_topology_context_uuid) > 0 and len(endpoint_topology_uuid) > 0:
#            str_topology_key = key_to_str([endpoint_topology_context_uuid, endpoint_topology_uuid])
#            str_endpoint_key = key_to_str([str_endpoint_key, str_topology_key], separator=':')
#
#        db_endpoint : EndPointModel = get_object(self.database, EndPointModel, str_endpoint_key)
#
#        str_service_endpoint_key = key_to_str([service_uuid, str_endpoint_key], separator='--')
#        result : Tuple[ServiceEndPointModel, bool] = get_or_create_object(
#            self.database, ServiceEndPointModel, str_service_endpoint_key, {
#                'service_fk': db_service, 'endpoint_fk': db_endpoint})
#        #db_service_endpoint, service_endpoint_created = result
#
#    event_type = EventTypeEnum.EVENTTYPE_UPDATE if updated else EventTypeEnum.EVENTTYPE_CREATE
#    dict_service_id = db_service.dump_id()
#    notify_event(self.messagebroker, TOPIC_SERVICE, event_type, {'service_id': dict_service_id})
#    return ServiceId(**dict_service_id)


#    def set_constraint(self, db_constraints: ConstraintsModel, grpc_constraint: Constraint, position: int
#    ) -> Tuple[Union_ConstraintModel, bool]:
#        with self.session() as session:
#
#            grpc_constraint_kind = str(grpc_constraint.WhichOneof('constraint'))
#
#            parser = CONSTRAINT_PARSERS.get(grpc_constraint_kind)
#            if parser is None:
#                raise NotImplementedError('Constraint of kind {:s} is not implemented: {:s}'.format(
#                    grpc_constraint_kind, grpc_message_to_json_string(grpc_constraint)))
#
#            # create specific constraint
#            constraint_class, str_constraint_id, constraint_data, constraint_kind = parser(grpc_constraint)
#            str_constraint_id = str(uuid.uuid4())
#            LOGGER.info('str_constraint_id: {}'.format(str_constraint_id))
#            # str_constraint_key_hash = fast_hasher(':'.join([constraint_kind.value, str_constraint_id]))
#            # str_constraint_key = key_to_str([db_constraints.pk, str_constraint_key_hash], separator=':')
#
#            # result : Tuple[Union_ConstraintModel, bool] = update_or_create_object(
#            #     database, constraint_class, str_constraint_key, constraint_data)
#            constraint_data[constraint_class.main_pk_name()] = str_constraint_id
#            db_new_constraint = constraint_class(**constraint_data)
#            result: Tuple[Union_ConstraintModel, bool] = self.database.create_or_update(db_new_constraint)
#            db_specific_constraint, updated = result
#
#            # create generic constraint
#            # constraint_fk_field_name = 'constraint_uuid'.format(constraint_kind.value)
#            constraint_data = {
#                'constraints_uuid': db_constraints.constraints_uuid, 'position': position, 'kind': constraint_kind
#            }
#
#            db_new_constraint = ConstraintModel(**constraint_data)
#            result: Tuple[Union_ConstraintModel, bool] = self.database.create_or_update(db_new_constraint)
#            db_constraint, updated = result
#
#            return db_constraint, updated
#
#    def set_constraints(self, service_uuid: str, constraints_name : str, grpc_constraints
#    ) -> List[Tuple[Union[ConstraintsModel, ConstraintModel], bool]]:
#        with self.session() as session:
#            # str_constraints_key = key_to_str([db_parent_pk, constraints_name], separator=':')
#            # result : Tuple[ConstraintsModel, bool] = get_or_create_object(database, ConstraintsModel, str_constraints_key)
#            result = session.query(ConstraintsModel).filter_by(constraints_uuid=service_uuid).one_or_none()
#            created = None
#            if result:
#                created = True
#            session.query(ConstraintsModel).filter_by(constraints_uuid=service_uuid).one_or_none()
#            db_constraints = ConstraintsModel(constraints_uuid=service_uuid)
#            session.add(db_constraints)
#
#            db_objects = [(db_constraints, created)]
#
#            for position,grpc_constraint in enumerate(grpc_constraints):
#                result : Tuple[ConstraintModel, bool] = self.set_constraint(
#                    db_constraints, grpc_constraint, position)
#                db_constraint, updated = result
#                db_objects.append((db_constraint, updated))
#
#            return db_objects

def service_delete(db_engine : Engine, request : ServiceId) -> bool:
    """Delete the service identified by its (context_uuid, service_uuid) pair.

    :param db_engine: SQLAlchemy engine the transaction is opened on.
    :param request: ServiceId carrying the context and service UUIDs.
    :return: True if at least one row was removed, False if nothing matched.
    """
    context_uuid = request.context_id.context_uuid.uuid
    service_uuid = request.service_uuid.uuid

    def _remove(session : Session) -> bool:
        matching = session.query(ServiceModel).filter_by(
            context_uuid=context_uuid, service_uuid=service_uuid)
        deleted_count = matching.delete()
        return deleted_count > 0

    session_factory = sessionmaker(bind=db_engine)
    return run_transaction(session_factory, _remove)

    # def delete(self) -> None:
    #     from .RelationModels import ServiceEndPointModel
    #     for db_service_endpoint_pk,_ in self.references(ServiceEndPointModel):
    #         ServiceEndPointModel(self.database, db_service_endpoint_pk).delete()
    #     super().delete()
    #     ConfigModel(self.database, self.service_config_fk).delete()
    #     ConstraintsModel(self.database, self.service_constraints_fk).delete()
