# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging, operator
from typing import Dict, List, Optional, Tuple, Type, Union
from common.orm.HighLevel import get_object, get_or_create_object, update_or_create_object
from common.orm.backend.Tools import key_to_str
from common.proto.context_pb2 import Constraint
from common.tools.grpc.Tools import grpc_message_to_json_string
from .EndPointModel import EndPointModel
from .Tools import fast_hasher
from sqlalchemy import Column, ForeignKey, String, Float, CheckConstraint, Integer, Boolean, Enum
from sqlalchemy.dialects.postgresql import UUID
from context.service.database.models._Base import Base
import enum

LOGGER = logging.getLogger(__name__)

def remove_dict_key(dictionary : Dict, key : str):
    dictionary.pop(key, None)
    return dictionary

class ConstraintsModel(Base): # pylint: disable=abstract-method
    __tablename__ = 'Constraints'
    constraints_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)

    @staticmethod
    def main_pk_name():
        return 'constraints_uuid'


    def dump(self, constraints) -> List[Dict]:
        constraints = sorted(constraints, key=operator.itemgetter('position'))
        return [remove_dict_key(constraint, 'position') for constraint in constraints]


class ConstraintCustomModel(Base): # pylint: disable=abstract-method
    __tablename__ = 'ConstraintCustom'
    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
    constraint_type = Column(String, nullable=False)
    constraint_value = Column(String, nullable=False)

    @staticmethod
    def main_pk_name():
        return 'constraint_uuid'


    def dump(self) -> Dict: # pylint: disable=arguments-differ
        return {'custom': {'constraint_type': self.constraint_type, 'constraint_value': self.constraint_value}}


Union_ConstraintEndpoint = Union[
    'ConstraintEndpointLocationGpsPositionModel', 'ConstraintEndpointLocationRegionModel',
    'ConstraintEndpointPriorityModel'
]

class ConstraintEndpointLocationRegionModel(Model): # pylint: disable=abstract-method
    endpoint_fk = ForeignKeyField(EndPointModel)
    region = StringField(required=True, allow_empty=False)

    def dump(self) -> Dict: # pylint: disable=arguments-differ
        json_endpoint_id = EndPointModel(self.database, self.endpoint_fk).dump_id()
        return {'endpoint_location': {'endpoint_id': json_endpoint_id, 'location': {'region': self.region}}}

# def dump_endpoint_id(endpoint_constraint: Union_ConstraintEndpoint):
#     db_endpoints_pks = list(endpoint_constraint.references(EndPointModel))
#     num_endpoints = len(db_endpoints_pks)
#     if num_endpoints != 1:
#         raise Exception('Wrong number({:d}) of associated Endpoints with constraint'.format(num_endpoints))
#     db_endpoint = EndPointModel(endpoint_constraint.database, db_endpoints_pks[0])
#     return db_endpoint.dump_id()


class ConstraintEndpointLocationRegionModel(Base): # pylint: disable=abstract-method
    __tablename__ = 'ConstraintEndpointLocationRegion'
    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
    endpoint_uuid = Column(UUID(as_uuid=False), ForeignKey("EndPoint.endpoint_uuid"))
    region = Column(String, nullable=False)

    @staticmethod
    def main_pk_name():
        return 'constraint_uuid'

    def dump(self, endpoint) -> Dict: # pylint: disable=arguments-differ
        return {'endpoint_location': {'endpoint_id': endpoint.dump_id(), 'region': self.region}}

    def dump(self) -> Dict: # pylint: disable=arguments-differ
        gps_position = {'latitude': self.latitude, 'longitude': self.longitude}
        json_endpoint_id = EndPointModel(self.database, self.endpoint_fk).dump_id()
        return {'endpoint_location': {'endpoint_id': json_endpoint_id, 'location': {'gps_position': gps_position}}}

class ConstraintEndpointPriorityModel(Model): # pylint: disable=abstract-method
    endpoint_fk = ForeignKeyField(EndPointModel)
    priority = IntegerField(required=True, min_value=0)

    def dump(self) -> Dict: # pylint: disable=arguments-differ
        json_endpoint_id = EndPointModel(self.database, self.endpoint_fk).dump_id()
        return {'endpoint_priority': {'endpoint_id': json_endpoint_id, 'priority': self.priority}}

class ConstraintEndpointLocationGpsPositionModel(Base): # pylint: disable=abstract-method
    __tablename__ = 'ConstraintEndpointLocationGpsPosition'
    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
    endpoint_uuid = Column(UUID(as_uuid=False), ForeignKey("EndPoint.endpoint_uuid"))
    latitude = Column(Float, CheckConstraint('latitude > -90.0 AND latitude < 90.0'), nullable=False)
    longitude = Column(Float, CheckConstraint('longitude > -90.0 AND longitude < 90.0'), nullable=False)

    def dump(self, endpoint) -> Dict: # pylint: disable=arguments-differ
        gps_position = {'latitude': self.latitude, 'longitude': self.longitude}
        return {'endpoint_location': {'endpoint_id': endpoint.dump_id(), 'gps_position': gps_position}}


class ConstraintEndpointPriorityModel(Base): # pylint: disable=abstract-method
    __tablename__ = 'ConstraintEndpointPriority'
    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
    endpoint_uuid = Column(UUID(as_uuid=False), ForeignKey("EndPoint.endpoint_uuid"))
    # endpoint_fk = ForeignKeyField(EndPointModel)
    # priority = FloatField(required=True)
    priority = Column(Float, nullable=False)
    @staticmethod
    def main_pk_name():
        return 'constraint_uuid'

    def dump(self, endpoint) -> Dict: # pylint: disable=arguments-differ
        return {'endpoint_priority': {'endpoint_id': endpoint.dump_id(), 'priority': self.priority}}


class ConstraintSlaAvailabilityModel(Base): # pylint: disable=abstract-method
    __tablename__ = 'ConstraintSlaAvailability'
    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
    # num_disjoint_paths = IntegerField(required=True, min_value=1)
    num_disjoint_paths = Column(Integer, CheckConstraint('num_disjoint_paths > 1'), nullable=False)
    # all_active = BooleanField(required=True)
    all_active = Column(Boolean, nullable=False)
    @staticmethod
    def main_pk_name():
        return 'constraint_uuid'

    def dump(self) -> Dict: # pylint: disable=arguments-differ
        return {'sla_availability': {'num_disjoint_paths': self.num_disjoint_paths, 'all_active': self.all_active}}

# enum values should match name of field in ConstraintModel
class ConstraintKindEnum(enum.Enum):
    CUSTOM                        = 'custom'
    ENDPOINT_LOCATION_REGION      = 'ep_loc_region'
    ENDPOINT_LOCATION_GPSPOSITION = 'ep_loc_gpspos'
    ENDPOINT_PRIORITY             = 'ep_priority'
    SLA_AVAILABILITY              = 'sla_avail'

Union_SpecificConstraint = Union[
    ConstraintCustomModel, ConstraintEndpointLocationRegionModel, ConstraintEndpointLocationGpsPositionModel,
    ConstraintEndpointPriorityModel, ConstraintSlaAvailabilityModel,
]

class ConstraintModel(Base): # pylint: disable=abstract-method
    __tablename__ = 'Constraint'
    # pk = PrimaryKeyField()
    # constraints_fk = ForeignKeyField(ConstraintsModel)
    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
    constraints_uuid = Column(UUID(as_uuid=False), ForeignKey("Constraints.constraints_uuid"), primary_key=True)
    # kind = EnumeratedField(ConstraintKindEnum)
    kind = Column(Enum(ConstraintKindEnum, create_constraint=False, native_enum=False))
    # position = IntegerField(min_value=0, required=True)
    position = Column(Integer, CheckConstraint('position >= 0'), nullable=False)
    # constraint_custom_fk        = ForeignKeyField(ConstraintCustomModel, required=False)
    constraint_custom = Column(UUID(as_uuid=False), ForeignKey("ConstraintCustom.constraint_uuid"))
    # constraint_ep_loc_region_fk = ForeignKeyField(ConstraintEndpointLocationRegionModel, required=False)
    constraint_ep_loc_region = Column(UUID(as_uuid=False), ForeignKey("ConstraintEndpointLocationRegion.constraint_uuid"))
    # constraint_ep_loc_gpspos_fk = ForeignKeyField(ConstraintEndpointLocationGpsPositionModel, required=False)
    constraint_ep_loc_gpspos = Column(UUID(as_uuid=False), ForeignKey("ConstraintEndpointLocationGpsPosition.constraint_uuid"))
    # constraint_ep_priority_fk   = ForeignKeyField(ConstraintEndpointPriorityModel, required=False)
    constraint_ep_priority = Column(UUID(as_uuid=False), ForeignKey("ConstraintEndpointPriority.constraint_uuid"),)
    # constraint_sla_avail_fk     = ForeignKeyField(ConstraintSlaAvailabilityModel, required=False)
    constraint_sla_avail = Column(UUID(as_uuid=False), ForeignKey("ConstraintSlaAvailability.constraint_uuid"))

    @staticmethod
    def main_pk_name():
        return 'constraint_uuid'

    # def delete(self) -> None:
    #     field_name = 'constraint_{:s}_fk'.format(str(self.kind.value))
    #     specific_fk_value : Optional[ForeignKeyField] = getattr(self, field_name, None)
    #     if specific_fk_value is None:
    #         raise Exception('Unable to find constraint key for field_name({:s})'.format(field_name))
    #     specific_fk_class = getattr(ConstraintModel, field_name, None)
    #     foreign_model_class : Model = specific_fk_class.foreign_model
    #     super().delete()
    #     get_object(self.database, foreign_model_class, str(specific_fk_value)).delete()

    def dump(self, include_position=True) -> Dict: # pylint: disable=arguments-differ
        field_name = 'constraint_{:s}'.format(str(self.kind.value))
        specific_fk_value = getattr(self, field_name, None)
        if specific_fk_value is None:
            raise Exception('Unable to find constraint key for field_name({:s})'.format(field_name))
        specific_fk_class = getattr(ConstraintModel, field_name, None)
        foreign_model_class: Base = specific_fk_class.foreign_model
        constraint: Union_SpecificConstraint = get_object(self.database, foreign_model_class, str(specific_fk_value))
        result = constraint.dump()
        if include_position:
            result['position'] = self.position
        return result

Tuple_ConstraintSpecs = Tuple[Type, str, Dict, ConstraintKindEnum]

def parse_constraint_custom(grpc_constraint) -> Tuple_ConstraintSpecs:
    constraint_class = ConstraintCustomModel
    str_constraint_id = grpc_constraint.custom.constraint_type
    constraint_data = {
        'constraint_type' : grpc_constraint.custom.constraint_type,
        'constraint_value': grpc_constraint.custom.constraint_value,
    }
    return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.CUSTOM

def parse_constraint_endpoint_location(db_endpoint, grpc_constraint) -> Tuple_ConstraintSpecs:
    grpc_endpoint_id = grpc_constraint.endpoint_location.endpoint_id
    # str_endpoint_key, db_endpoint = get_endpoint(database, grpc_endpoint_id)

    str_constraint_id = db_endpoint.endpoint_uuid
    constraint_data = {'endpoint_fk': db_endpoint}

    grpc_location = grpc_constraint.endpoint_location.location
    location_kind = str(grpc_location.WhichOneof('location'))
    if location_kind == 'region':
        constraint_class = ConstraintEndpointLocationRegionModel
        constraint_data.update({'region': grpc_location.region})
        return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.ENDPOINT_LOCATION_REGION
    elif location_kind == 'gps_position':
        constraint_class = ConstraintEndpointLocationGpsPositionModel
        gps_position = grpc_location.gps_position
        constraint_data.update({'latitude': gps_position.latitude, 'longitude': gps_position.longitude})
        return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.ENDPOINT_LOCATION_GPSPOSITION
    else:
        MSG = 'Location kind {:s} in Constraint of kind endpoint_location is not implemented: {:s}'
        raise NotImplementedError(MSG.format(location_kind, grpc_message_to_json_string(grpc_constraint)))

def parse_constraint_endpoint_priority(db_endpoint, grpc_constraint) -> Tuple_ConstraintSpecs:
    grpc_endpoint_id = grpc_constraint.endpoint_priority.endpoint_id
    # str_endpoint_key, db_endpoint = get_endpoint(database, grpc_endpoint_id)

    constraint_class = ConstraintEndpointPriorityModel
    str_constraint_id = db_endpoint.endpoint_uuid
    priority = grpc_constraint.endpoint_priority.priority
    constraint_data = {'endpoint_fk': db_endpoint, 'priority': priority}

    return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.ENDPOINT_PRIORITY

def parse_constraint_sla_availability(grpc_constraint) -> Tuple_ConstraintSpecs:
    constraint_class = ConstraintSlaAvailabilityModel
    str_constraint_id = ''
    constraint_data = {
        'num_disjoint_paths' : grpc_constraint.sla_availability.num_disjoint_paths,
        'all_active': grpc_constraint.sla_availability.all_active,
    }
    return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.SLA_AVAILABILITY

CONSTRAINT_PARSERS = {
    'custom'            : parse_constraint_custom,
    'endpoint_location' : parse_constraint_endpoint_location,
    'endpoint_priority' : parse_constraint_endpoint_priority,
    'sla_availability'  : parse_constraint_sla_availability,
}

Union_ConstraintModel = Union[
    ConstraintCustomModel, ConstraintEndpointLocationGpsPositionModel, ConstraintEndpointLocationRegionModel,
    ConstraintEndpointPriorityModel, ConstraintSlaAvailabilityModel
]

# def set_constraint(
#     db_constraints : ConstraintsModel, grpc_constraint : Constraint, position : int
# ) -> Tuple[Union_ConstraintModel, bool]:
#     grpc_constraint_kind = str(grpc_constraint.WhichOneof('constraint'))
#
#     parser = CONSTRAINT_PARSERS.get(grpc_constraint_kind)
#     if parser is None:
#         raise NotImplementedError('Constraint of kind {:s} is not implemented: {:s}'.format(
#             grpc_constraint_kind, grpc_message_to_json_string(grpc_constraint)))
#
#     # create specific constraint
#     constraint_class, str_constraint_id, constraint_data, constraint_kind = parser(database, grpc_constraint)
#     str_constraint_key_hash = fast_hasher(':'.join([constraint_kind.value, str_constraint_id]))
#     str_constraint_key = key_to_str([db_constraints.pk, str_constraint_key_hash], separator=':')
#     result : Tuple[Union_ConstraintModel, bool] = update_or_create_object(
#         database, constraint_class, str_constraint_key, constraint_data)
#     db_specific_constraint, updated = result
#
#     # create generic constraint
#     constraint_fk_field_name = 'constraint_{:s}_fk'.format(constraint_kind.value)
#     constraint_data = {
#         'constraints_fk': db_constraints, 'position': position, 'kind': constraint_kind,
#         constraint_fk_field_name: db_specific_constraint
#     }
#     result : Tuple[ConstraintModel, bool] = update_or_create_object(
#         database, ConstraintModel, str_constraint_key, constraint_data)
#     db_constraint, updated = result
#
#     return db_constraint, updated
#
# def set_constraints(
#     database : Database, db_parent_pk : str, constraints_name : str, grpc_constraints
# ) -> List[Tuple[Union[ConstraintsModel, ConstraintModel], bool]]:
#
#     str_constraints_key = key_to_str([db_parent_pk, constraints_name], separator=':')
#     result : Tuple[ConstraintsModel, bool] = get_or_create_object(database, ConstraintsModel, str_constraints_key)
#     db_constraints, created = result
#
#     db_objects = [(db_constraints, created)]
#
#     for position,grpc_constraint in enumerate(grpc_constraints):
#         result : Tuple[ConstraintModel, bool] = set_constraint(
#             database, db_constraints, grpc_constraint, position)
#         db_constraint, updated = result
#         db_objects.append((db_constraint, updated))
#
#     return db_objects
def set_constraint(
    database : Database, db_constraints : ConstraintsModel, grpc_constraint : Constraint, position : int
) -> Tuple[Union_ConstraintModel, bool]:
    grpc_constraint_kind = str(grpc_constraint.WhichOneof('constraint'))

    parser = CONSTRAINT_PARSERS.get(grpc_constraint_kind)
    if parser is None:
        raise NotImplementedError('Constraint of kind {:s} is not implemented: {:s}'.format(
            grpc_constraint_kind, grpc_message_to_json_string(grpc_constraint)))

    # create specific constraint
    constraint_class, str_constraint_id, constraint_data, constraint_kind = parser(database, grpc_constraint)
    str_constraint_key_hash = fast_hasher(':'.join([constraint_kind.value, str_constraint_id]))
    str_constraint_key = key_to_str([db_constraints.pk, str_constraint_key_hash], separator=':')
    result : Tuple[Union_ConstraintModel, bool] = update_or_create_object(
        database, constraint_class, str_constraint_key, constraint_data)
    db_specific_constraint, updated = result

    # create generic constraint
    constraint_fk_field_name = 'constraint_{:s}_fk'.format(constraint_kind.value)
    constraint_data = {
        'constraints_fk': db_constraints, 'position': position, 'kind': constraint_kind,
        constraint_fk_field_name: db_specific_constraint
    }
    result : Tuple[ConstraintModel, bool] = update_or_create_object(
        database, ConstraintModel, str_constraint_key, constraint_data)
    db_constraint, updated = result

    return db_constraint, updated

def set_constraints(
    database : Database, db_parent_pk : str, constraints_name : str, grpc_constraints
) -> List[Tuple[Union[ConstraintsModel, ConstraintModel], bool]]:

    str_constraints_key = key_to_str([constraints_name, db_parent_pk], separator=':')
    result : Tuple[ConstraintsModel, bool] = get_or_create_object(database, ConstraintsModel, str_constraints_key)
    db_constraints, created = result

    db_objects = [(db_constraints, created)]

    for position,grpc_constraint in enumerate(grpc_constraints):
        result : Tuple[ConstraintModel, bool] = set_constraint(
            database, db_constraints, grpc_constraint, position)
        db_constraint, updated = result
        db_objects.append((db_constraint, updated))

    return db_objects