# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum, json
from sqlalchemy import CheckConstraint, Column, Enum, ForeignKey, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from typing import Dict
from ._Base import _Base

class ConstraintKindEnum(enum.Enum):
    """Kinds of constraint that a ConstraintModel row can hold.

    ConstraintModel stores the member in its ``kind`` column. ConstraintModel.dump()
    uses the member's ``value`` as the single top-level key of the dict it returns.
    Do not rename or reorder members without considering already-persisted rows.
    """
    CUSTOM = 'custom'
    ENDPOINT_LOCATION_REGION = 'ep_loc_region'
    ENDPOINT_LOCATION_GPSPOSITION = 'ep_loc_gpspos'
    ENDPOINT_PRIORITY = 'ep_priority'
    SLA_AVAILABILITY = 'sla_avail'

class ConstraintModel(_Base):
    """ORM row for one constraint attached to a service or a slice.

    Columns:
      constraint_uuid: primary key.
      service_uuid: nullable FK to ``service.service_uuid`` (ON DELETE CASCADE).
      slice_uuid: nullable FK to ``slice.slice_uuid`` (ON DELETE CASCADE).
      position: non-negative ordering index of the constraint within its parent
          (enforced by the ``check_position_value`` check constraint).
      kind: ConstraintKindEnum member identifying the constraint type.
      data: JSON-encoded constraint payload (decoded by ``dump``).
    """
    __tablename__ = 'constraint'

    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True)
    # NOTE(review): both parent FKs are nullable and nothing in this model enforces
    # that exactly one of them is set -- presumably guaranteed by callers; confirm.
    service_uuid = Column(ForeignKey('service.service_uuid', ondelete='CASCADE'), nullable=True)
    slice_uuid = Column(ForeignKey('slice.slice_uuid', ondelete='CASCADE'), nullable=True)
    position = Column(Integer, nullable=False)
    kind = Column(Enum(ConstraintKindEnum), nullable=False)
    data = Column(String, nullable=False)

    __table_args__ = (
        # `position` here is the Column object above, so this builds a SQL expression.
        CheckConstraint(position >= 0, name='check_position_value'),
        #UniqueConstraint('service_uuid', 'position', name='unique_per_service'),
        #UniqueConstraint('slice_uuid',   'position', name='unique_per_slice'  ),
    )

    def dump(self) -> Dict:
        """Return this constraint as a single-entry dict.

        The key is ``self.kind.value`` (e.g. ``'custom'``). The value is ``self.data``
        decoded from JSON.
        """
        kind_key = self.kind.value
        payload = json.loads(self.data)
        return {kind_key: payload}


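# NOTE(review): everything below is commented-out code that appears to be an
# earlier implementation of the constraint models (note its different table
# names and ConstraintModel layout). It is not executed; kept for reference only.
#import logging, operator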
#from typing import Dict, List, Optional, Tuple, Type, Union
#from common.orm.HighLevel import get_object, get_or_create_object, update_or_create_object
#from common.orm.backend.Tools import key_to_str
#from common.proto.context_pb2 import Constraint
#from common.tools.grpc.Tools import grpc_message_to_json_string
#from .EndPointModel import EndPointModel
#from .Tools import fast_hasher
#from sqlalchemy import Column, ForeignKey, String, Float, CheckConstraint, Integer, Boolean, Enum
#from sqlalchemy.dialects.postgresql import UUID
#from context.service.database.models._Base import Base
#import enum
#
#LOGGER = logging.getLogger(__name__)
#
#def remove_dict_key(dictionary : Dict, key : str):
#    dictionary.pop(key, None)
#    return dictionary
#
#class ConstraintsModel(Base): # pylint: disable=abstract-method
#    __tablename__ = 'Constraints'
#    constraints_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
#
#    @staticmethod
#    def main_pk_name():
#        return 'constraints_uuid'
#
#
#    def dump(self, constraints) -> List[Dict]:
#        constraints = sorted(constraints, key=operator.itemgetter('position'))
#        return [remove_dict_key(constraint, 'position') for constraint in constraints]
#
#
#class ConstraintCustomModel(Base): # pylint: disable=abstract-method
#    __tablename__ = 'ConstraintCustom'
#    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
#    constraint_type = Column(String, nullable=False)
#    constraint_value = Column(String, nullable=False)
#
#    @staticmethod
#    def main_pk_name():
#        return 'constraint_uuid'
#
#
#    def dump(self) -> Dict: # pylint: disable=arguments-differ
#        return {'custom': {'constraint_type': self.constraint_type, 'constraint_value': self.constraint_value}}
#
#
#Union_ConstraintEndpoint = Union[
#    'ConstraintEndpointLocationGpsPositionModel', 'ConstraintEndpointLocationRegionModel',
#    'ConstraintEndpointPriorityModel'
#]
#
#class ConstraintEndpointLocationRegionModel(Model): # pylint: disable=abstract-method
#    endpoint_fk = ForeignKeyField(EndPointModel)
#    region = StringField(required=True, allow_empty=False)
#
#    def dump(self) -> Dict: # pylint: disable=arguments-differ
#        json_endpoint_id = EndPointModel(self.database, self.endpoint_fk).dump_id()
#        return {'endpoint_location': {'endpoint_id': json_endpoint_id, 'location': {'region': self.region}}}
#
## def dump_endpoint_id(endpoint_constraint: Union_ConstraintEndpoint):
##     db_endpoints_pks = list(endpoint_constraint.references(EndPointModel))
##     num_endpoints = len(db_endpoints_pks)
##     if num_endpoints != 1:
##         raise Exception('Wrong number({:d}) of associated Endpoints with constraint'.format(num_endpoints))
##     db_endpoint = EndPointModel(endpoint_constraint.database, db_endpoints_pks[0])
##     return db_endpoint.dump_id()
#
#
#class ConstraintEndpointLocationRegionModel(Base): # pylint: disable=abstract-method
#    __tablename__ = 'ConstraintEndpointLocationRegion'
#    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
#    endpoint_uuid = Column(UUID(as_uuid=False), ForeignKey("EndPoint.endpoint_uuid"))
#    region = Column(String, nullable=False)
#
#    @staticmethod
#    def main_pk_name():
#        return 'constraint_uuid'
#
#    def dump(self, endpoint) -> Dict: # pylint: disable=arguments-differ
#        return {'endpoint_location': {'endpoint_id': endpoint.dump_id(), 'region': self.region}}
#
#    def dump(self) -> Dict: # pylint: disable=arguments-differ
#        gps_position = {'latitude': self.latitude, 'longitude': self.longitude}
#        json_endpoint_id = EndPointModel(self.database, self.endpoint_fk).dump_id()
#        return {'endpoint_location': {'endpoint_id': json_endpoint_id, 'location': {'gps_position': gps_position}}}
#
#class ConstraintEndpointPriorityModel(Model): # pylint: disable=abstract-method
#    endpoint_fk = ForeignKeyField(EndPointModel)
#    priority = IntegerField(required=True, min_value=0)
#
#    def dump(self) -> Dict: # pylint: disable=arguments-differ
#        json_endpoint_id = EndPointModel(self.database, self.endpoint_fk).dump_id()
#        return {'endpoint_priority': {'endpoint_id': json_endpoint_id, 'priority': self.priority}}
#
#class ConstraintEndpointLocationGpsPositionModel(Base): # pylint: disable=abstract-method
#    __tablename__ = 'ConstraintEndpointLocationGpsPosition'
#    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
#    endpoint_uuid = Column(UUID(as_uuid=False), ForeignKey("EndPoint.endpoint_uuid"))
#    latitude = Column(Float, CheckConstraint('latitude > -90.0 AND latitude < 90.0'), nullable=False)
#    longitude = Column(Float, CheckConstraint('longitude > -90.0 AND longitude < 90.0'), nullable=False)
#
#    def dump(self, endpoint) -> Dict: # pylint: disable=arguments-differ
#        gps_position = {'latitude': self.latitude, 'longitude': self.longitude}
#        return {'endpoint_location': {'endpoint_id': endpoint.dump_id(), 'gps_position': gps_position}}
#
#
#class ConstraintEndpointPriorityModel(Base): # pylint: disable=abstract-method
#    __tablename__ = 'ConstraintEndpointPriority'
#    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
#    endpoint_uuid = Column(UUID(as_uuid=False), ForeignKey("EndPoint.endpoint_uuid"))
#    # endpoint_fk = ForeignKeyField(EndPointModel)
#    # priority = FloatField(required=True)
#    priority = Column(Float, nullable=False)
#    @staticmethod
#    def main_pk_name():
#        return 'constraint_uuid'
#
#    def dump(self, endpoint) -> Dict: # pylint: disable=arguments-differ
#        return {'endpoint_priority': {'endpoint_id': endpoint.dump_id(), 'priority': self.priority}}
#
#
#class ConstraintSlaAvailabilityModel(Base): # pylint: disable=abstract-method
#    __tablename__ = 'ConstraintSlaAvailability'
#    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
#    # num_disjoint_paths = IntegerField(required=True, min_value=1)
#    num_disjoint_paths = Column(Integer, CheckConstraint('num_disjoint_paths > 1'), nullable=False)
#    # all_active = BooleanField(required=True)
#    all_active = Column(Boolean, nullable=False)
#    @staticmethod
#    def main_pk_name():
#        return 'constraint_uuid'
#
#    def dump(self) -> Dict: # pylint: disable=arguments-differ
#        return {'sla_availability': {'num_disjoint_paths': self.num_disjoint_paths, 'all_active': self.all_active}}
#
#Union_SpecificConstraint = Union[
#    ConstraintCustomModel, ConstraintEndpointLocationRegionModel, ConstraintEndpointLocationGpsPositionModel,
#    ConstraintEndpointPriorityModel, ConstraintSlaAvailabilityModel,
#]
#
#class ConstraintModel(Base): # pylint: disable=abstract-method
#    __tablename__ = 'Constraint'
#    # pk = PrimaryKeyField()
#    # constraints_fk = ForeignKeyField(ConstraintsModel)
#    constraint_uuid = Column(UUID(as_uuid=False), primary_key=True, unique=True)
#    constraints_uuid = Column(UUID(as_uuid=False), ForeignKey("Constraints.constraints_uuid"), primary_key=True)
#    # kind = EnumeratedField(ConstraintKindEnum)
#    kind = Column(Enum(ConstraintKindEnum, create_constraint=False, native_enum=False))
#    # position = IntegerField(min_value=0, required=True)
#    position = Column(Integer, CheckConstraint('position >= 0'), nullable=False)
#    # constraint_custom_fk        = ForeignKeyField(ConstraintCustomModel, required=False)
#    constraint_custom = Column(UUID(as_uuid=False), ForeignKey("ConstraintCustom.constraint_uuid"))
#    # constraint_ep_loc_region_fk = ForeignKeyField(ConstraintEndpointLocationRegionModel, required=False)
#    constraint_ep_loc_region = Column(UUID(as_uuid=False), ForeignKey("ConstraintEndpointLocationRegion.constraint_uuid"))
#    # constraint_ep_loc_gpspos_fk = ForeignKeyField(ConstraintEndpointLocationGpsPositionModel, required=False)
#    constraint_ep_loc_gpspos = Column(UUID(as_uuid=False), ForeignKey("ConstraintEndpointLocationGpsPosition.constraint_uuid"))
#    # constraint_ep_priority_fk   = ForeignKeyField(ConstraintEndpointPriorityModel, required=False)
#    constraint_ep_priority = Column(UUID(as_uuid=False), ForeignKey("ConstraintEndpointPriority.constraint_uuid"),)
#    # constraint_sla_avail_fk     = ForeignKeyField(ConstraintSlaAvailabilityModel, required=False)
#    constraint_sla_avail = Column(UUID(as_uuid=False), ForeignKey("ConstraintSlaAvailability.constraint_uuid"))
#
#    @staticmethod
#    def main_pk_name():
#        return 'constraint_uuid'
#
#    # def delete(self) -> None:
#    #     field_name = 'constraint_{:s}_fk'.format(str(self.kind.value))
#    #     specific_fk_value : Optional[ForeignKeyField] = getattr(self, field_name, None)
#    #     if specific_fk_value is None:
#    #         raise Exception('Unable to find constraint key for field_name({:s})'.format(field_name))
#    #     specific_fk_class = getattr(ConstraintModel, field_name, None)
#    #     foreign_model_class : Model = specific_fk_class.foreign_model
#    #     super().delete()
#    #     get_object(self.database, foreign_model_class, str(specific_fk_value)).delete()
#
#    def dump(self, include_position=True) -> Dict: # pylint: disable=arguments-differ
#        field_name = 'constraint_{:s}'.format(str(self.kind.value))
#        specific_fk_value = getattr(self, field_name, None)
#        if specific_fk_value is None:
#            raise Exception('Unable to find constraint key for field_name({:s})'.format(field_name))
#        specific_fk_class = getattr(ConstraintModel, field_name, None)
#        foreign_model_class: Base = specific_fk_class.foreign_model
#        constraint: Union_SpecificConstraint = get_object(self.database, foreign_model_class, str(specific_fk_value))
#        result = constraint.dump()
#        if include_position:
#            result['position'] = self.position
#        return result
#
#Tuple_ConstraintSpecs = Tuple[Type, str, Dict, ConstraintKindEnum]
#
#def parse_constraint_custom(grpc_constraint) -> Tuple_ConstraintSpecs:
#    constraint_class = ConstraintCustomModel
#    str_constraint_id = grpc_constraint.custom.constraint_type
#    constraint_data = {
#        'constraint_type' : grpc_constraint.custom.constraint_type,
#        'constraint_value': grpc_constraint.custom.constraint_value,
#    }
#    return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.CUSTOM
#
#def parse_constraint_endpoint_location(db_endpoint, grpc_constraint) -> Tuple_ConstraintSpecs:
#    grpc_endpoint_id = grpc_constraint.endpoint_location.endpoint_id
#    # str_endpoint_key, db_endpoint = get_endpoint(database, grpc_endpoint_id)
#
#    str_constraint_id = db_endpoint.endpoint_uuid
#    constraint_data = {'endpoint_fk': db_endpoint}
#
#    grpc_location = grpc_constraint.endpoint_location.location
#    location_kind = str(grpc_location.WhichOneof('location'))
#    if location_kind == 'region':
#        constraint_class = ConstraintEndpointLocationRegionModel
#        constraint_data.update({'region': grpc_location.region})
#        return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.ENDPOINT_LOCATION_REGION
#    elif location_kind == 'gps_position':
#        constraint_class = ConstraintEndpointLocationGpsPositionModel
#        gps_position = grpc_location.gps_position
#        constraint_data.update({'latitude': gps_position.latitude, 'longitude': gps_position.longitude})
#        return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.ENDPOINT_LOCATION_GPSPOSITION
#    else:
#        MSG = 'Location kind {:s} in Constraint of kind endpoint_location is not implemented: {:s}'
#        raise NotImplementedError(MSG.format(location_kind, grpc_message_to_json_string(grpc_constraint)))
#
#def parse_constraint_endpoint_priority(db_endpoint, grpc_constraint) -> Tuple_ConstraintSpecs:
#    grpc_endpoint_id = grpc_constraint.endpoint_priority.endpoint_id
#    # str_endpoint_key, db_endpoint = get_endpoint(database, grpc_endpoint_id)
#
#    constraint_class = ConstraintEndpointPriorityModel
#    str_constraint_id = db_endpoint.endpoint_uuid
#    priority = grpc_constraint.endpoint_priority.priority
#    constraint_data = {'endpoint_fk': db_endpoint, 'priority': priority}
#
#    return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.ENDPOINT_PRIORITY
#
#def parse_constraint_sla_availability(grpc_constraint) -> Tuple_ConstraintSpecs:
#    constraint_class = ConstraintSlaAvailabilityModel
#    str_constraint_id = ''
#    constraint_data = {
#        'num_disjoint_paths' : grpc_constraint.sla_availability.num_disjoint_paths,
#        'all_active': grpc_constraint.sla_availability.all_active,
#    }
#    return constraint_class, str_constraint_id, constraint_data, ConstraintKindEnum.SLA_AVAILABILITY
#
#CONSTRAINT_PARSERS = {
#    'custom'            : parse_constraint_custom,
#    'endpoint_location' : parse_constraint_endpoint_location,
#    'endpoint_priority' : parse_constraint_endpoint_priority,
#    'sla_availability'  : parse_constraint_sla_availability,
#}
#
#Union_ConstraintModel = Union[
#    ConstraintCustomModel, ConstraintEndpointLocationGpsPositionModel, ConstraintEndpointLocationRegionModel,
#    ConstraintEndpointPriorityModel, ConstraintSlaAvailabilityModel
#]
#
## def set_constraint(
##     db_constraints : ConstraintsModel, grpc_constraint : Constraint, position : int
## ) -> Tuple[Union_ConstraintModel, bool]:
##     grpc_constraint_kind = str(grpc_constraint.WhichOneof('constraint'))
##
##     parser = CONSTRAINT_PARSERS.get(grpc_constraint_kind)
##     if parser is None:
##         raise NotImplementedError('Constraint of kind {:s} is not implemented: {:s}'.format(
##             grpc_constraint_kind, grpc_message_to_json_string(grpc_constraint)))
##
##     # create specific constraint
##     constraint_class, str_constraint_id, constraint_data, constraint_kind = parser(database, grpc_constraint)
##     str_constraint_key_hash = fast_hasher(':'.join([constraint_kind.value, str_constraint_id]))
##     str_constraint_key = key_to_str([db_constraints.pk, str_constraint_key_hash], separator=':')
##     result : Tuple[Union_ConstraintModel, bool] = update_or_create_object(
##         database, constraint_class, str_constraint_key, constraint_data)
##     db_specific_constraint, updated = result
##
##     # create generic constraint
##     constraint_fk_field_name = 'constraint_{:s}_fk'.format(constraint_kind.value)
##     constraint_data = {
##         'constraints_fk': db_constraints, 'position': position, 'kind': constraint_kind,
##         constraint_fk_field_name: db_specific_constraint
##     }
##     result : Tuple[ConstraintModel, bool] = update_or_create_object(
##         database, ConstraintModel, str_constraint_key, constraint_data)
##     db_constraint, updated = result
##
##     return db_constraint, updated
##
## def set_constraints(
##     database : Database, db_parent_pk : str, constraints_name : str, grpc_constraints
## ) -> List[Tuple[Union[ConstraintsModel, ConstraintModel], bool]]:
##
##     str_constraints_key = key_to_str([db_parent_pk, constraints_name], separator=':')
##     result : Tuple[ConstraintsModel, bool] = get_or_create_object(database, ConstraintsModel, str_constraints_key)
##     db_constraints, created = result
##
##     db_objects = [(db_constraints, created)]
##
##     for position,grpc_constraint in enumerate(grpc_constraints):
##         result : Tuple[ConstraintModel, bool] = set_constraint(
##             database, db_constraints, grpc_constraint, position)
##         db_constraint, updated = result
##         db_objects.append((db_constraint, updated))
##
##     return db_objects
#def set_constraint(
#    database : Database, db_constraints : ConstraintsModel, grpc_constraint : Constraint, position : int
#) -> Tuple[Union_ConstraintModel, bool]:
#    grpc_constraint_kind = str(grpc_constraint.WhichOneof('constraint'))
#
#    parser = CONSTRAINT_PARSERS.get(grpc_constraint_kind)
#    if parser is None:
#        raise NotImplementedError('Constraint of kind {:s} is not implemented: {:s}'.format(
#            grpc_constraint_kind, grpc_message_to_json_string(grpc_constraint)))
#
#    # create specific constraint
#    constraint_class, str_constraint_id, constraint_data, constraint_kind = parser(database, grpc_constraint)
#    str_constraint_key_hash = fast_hasher(':'.join([constraint_kind.value, str_constraint_id]))
#    str_constraint_key = key_to_str([db_constraints.pk, str_constraint_key_hash], separator=':')
#    result : Tuple[Union_ConstraintModel, bool] = update_or_create_object(
#        database, constraint_class, str_constraint_key, constraint_data)
#    db_specific_constraint, updated = result
#
#    # create generic constraint
#    constraint_fk_field_name = 'constraint_{:s}_fk'.format(constraint_kind.value)
#    constraint_data = {
#        'constraints_fk': db_constraints, 'position': position, 'kind': constraint_kind,
#        constraint_fk_field_name: db_specific_constraint
#    }
#    result : Tuple[ConstraintModel, bool] = update_or_create_object(
#        database, ConstraintModel, str_constraint_key, constraint_data)
#    db_constraint, updated = result
#
#    return db_constraint, updated
#
#def set_constraints(
#    database : Database, db_parent_pk : str, constraints_name : str, grpc_constraints
#) -> List[Tuple[Union[ConstraintsModel, ConstraintModel], bool]]:
#
#    str_constraints_key = key_to_str([constraints_name, db_parent_pk], separator=':')
#    result : Tuple[ConstraintsModel, bool] = get_or_create_object(database, ConstraintsModel, str_constraints_key)
#    db_constraints, created = result
#
#    db_objects = [(db_constraints, created)]
#
#    for position,grpc_constraint in enumerate(grpc_constraints):
#        result : Tuple[ConstraintModel, bool] = set_constraint(
#            database, db_constraints, grpc_constraint, position)
#        db_constraint, updated = result
#        db_objects.append((db_constraint, updated))
#
#    return db_objects