Skip to content
Snippets Groups Projects
ConnectionModel.py 6.35 KiB
Newer Older
# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging, operator
from typing import Dict, List, Optional, Set, Tuple, Union
from common.orm.Database import Database
from common.orm.backend.Tools import key_to_str
from common.orm.fields.ForeignKeyField import ForeignKeyField
from common.orm.fields.IntegerField import IntegerField
from common.orm.fields.PrimaryKeyField import PrimaryKeyField
from common.orm.fields.StringField import StringField
from common.orm.model.Model import Model
from common.orm.HighLevel import get_object, get_or_create_object, get_related_objects, update_or_create_object
from common.proto.context_pb2 import EndPointId
from .EndPointModel import EndPointModel
from .ServiceModel import ServiceModel
from .Tools import remove_dict_key

# Module-level logger, named after this module.
LOGGER = logging.getLogger(name=__name__)

class PathModel(Model): # pylint: disable=abstract-method
    # A path is only an anchor: its hops (PathHopModel) reference it through their path_fk.
    pk = PrimaryKeyField()

    def dump(self) -> List[Dict]:
        """Return the endpoint ids of this path's hops, ordered by hop position.

        Every PathHopModel referencing this path is dumped together with its position. The
        dumps are sorted on that position, and the ``'position'`` key is then stripped from
        each one via ``remove_dict_key``.
        """
        hop_references = self.references(PathHopModel)
        ordered_hops = sorted(
            (PathHopModel(self.database, hop_pk).dump(include_position=True) for hop_pk, _ in hop_references),
            key=operator.itemgetter('position'),
        )
        return [remove_dict_key(hop, 'position') for hop in ordered_hops]

class PathHopModel(Model): # pylint: disable=abstract-method
    pk = PrimaryKeyField()
    # Path this hop belongs to.
    path_fk = ForeignKeyField(PathModel)
    # Index of the hop inside its path; mandatory and never negative (min_value=0).
    position = IntegerField(min_value=0, required=True)
    # Endpoint traversed at this hop.
    endpoint_fk = ForeignKeyField(EndPointModel)

    def dump(self, include_position=True) -> Dict: # pylint: disable=arguments-differ
        """Return the hop's endpoint id, optionally extended with a ``'position'`` entry.

        Args:
            include_position: when truthy, ``self.position`` is added under ``'position'``.

        Returns:
            The dict produced by ``EndPointModel.dump_id()`` for this hop's endpoint.
        """
        endpoint_id = EndPointModel(self.database, self.endpoint_fk).dump_id()
        if not include_position:
            return endpoint_id
        endpoint_id['position'] = self.position
        return endpoint_id

class ConnectionModel(Model): # pylint: disable=abstract-method
    pk = PrimaryKeyField()
    # Connection identifier; mandatory and may not be empty.
    connection_uuid = StringField(required=True, allow_empty=False)
    # Optional owning service; when None, dump() omits 'service_id'.
    service_fk = ForeignKeyField(ServiceModel, required=False)
    # Mandatory path holding the connection's hops.
    path_fk = ForeignKeyField(PathModel, required=True)

    def dump_id(self) -> Dict:
        """Return this connection's identifier as ``{'connection_uuid': {'uuid': <uuid>}}``."""
        uuid_entry = {'uuid': self.connection_uuid}
        return {'connection_uuid': uuid_entry}

    def dump_path_hops_endpoint_ids(self) -> List[Dict]:
        """Return the endpoint ids of this connection's path, as produced by ``PathModel.dump()``."""
        db_path = PathModel(self.database, self.path_fk)
        return db_path.dump()

    def dump_sub_service_ids(self) -> List[Dict]:
        """Return ``dump_id()`` of every sub-service related through ConnectionSubServiceModel, sorted by pk."""
        # Imported at call time; presumably to avoid an import cycle with RelationModels -- confirm.
        from .RelationModels import ConnectionSubServiceModel # pylint: disable=import-outside-toplevel
        related_services = get_related_objects(self, ConnectionSubServiceModel, 'sub_service_fk')
        ordered_services = sorted(related_services, key=operator.attrgetter('pk'))
        return [db_service.dump_id() for db_service in ordered_services]

    def dump(self, include_path=True, include_sub_service_ids=True) -> Dict: # pylint: disable=arguments-differ
        """Serialize this connection to a dict.

        Args:
            include_path: when truthy, add 'path_hops_endpoint_ids'.
            include_sub_service_ids: when truthy, add 'sub_service_ids'.

        Returns:
            A dict that always contains 'connection_id'. It also contains 'service_id' when
            ``service_fk`` is set, plus the optional sections requested above, in that order.
        """
        dumped : Dict = {'connection_id': self.dump_id()}
        if self.service_fk is not None:
            db_service = ServiceModel(self.database, self.service_fk)
            dumped['service_id'] = db_service.dump_id()
        if include_path:
            dumped['path_hops_endpoint_ids'] = self.dump_path_hops_endpoint_ids()
        if include_sub_service_ids:
            dumped['sub_service_ids'] = self.dump_sub_service_ids()
        return dumped

def set_path_hop(
        database : Database, db_path : PathModel, position : int, db_endpoint : EndPointModel
    ) -> Tuple[PathHopModel, bool]:
    """Create or update the hop placing ``db_endpoint`` at ``position`` inside ``db_path``.

    The hop key combines the path pk and the endpoint pk, joined with ':' by ``key_to_str``.
    The same endpoint within the same path therefore always maps to the same hop record.

    Returns:
        The ``(PathHopModel, flag)`` pair returned by ``update_or_create_object``. The flag
        presumably reports whether an existing hop was updated; confirm in common.orm.HighLevel.
    """
    hop_key = key_to_str([db_path.pk, db_endpoint.pk], separator=':')
    hop_fields = {'path_fk': db_path, 'position': position, 'endpoint_fk': db_endpoint}
    db_path_hop, updated = update_or_create_object(database, PathHopModel, hop_key, hop_fields)
    return db_path_hop, updated

def delete_path_hop(
        database : Database, db_path : PathModel, db_path_hop_pk : str
    ) -> None:
    """Delete the PathHopModel stored under ``db_path_hop_pk``; do nothing if it is not found.

    ``db_path`` is not used by the body; it is kept so existing call sites keep working.
    """
    db_path_hop : Optional[PathHopModel] = get_object(
        database, PathHopModel, db_path_hop_pk, raise_if_not_found=False)
    if db_path_hop is not None:
        db_path_hop.delete()

def delete_all_path_hops(
        database : Database, db_path : PathModel
    ) -> None:
    """Delete every PathHopModel that references ``db_path``.

    Only the hops are deleted; the path object itself is left in place.

    Args:
        database: ORM database the hops are stored in.
        db_path: the path whose hops are removed. The previous annotation said PathHopModel,
            but the value is the path that the hops reference, as in the sibling helpers.
    """
    # Snapshot the references before deleting. Whether references() returns a fresh copy or a
    # live view of the backend is not visible here, so avoid mutating what is being iterated.
    db_path_hop_pks = list(db_path.references(PathHopModel))
    for db_path_hop_pk, _ in db_path_hop_pks:
        PathHopModel(database, db_path_hop_pk).delete()

def _endpoint_key_from_id(endpoint_id : EndPointId) -> str:
    """Build the EndPointModel primary key that identifies ``endpoint_id``.

    The key joins the device and endpoint UUIDs with ``key_to_str``. When both the topology's
    context UUID and the topology UUID are non-empty, the topology key is appended with ':'.
    """
    str_endpoint_key = key_to_str([endpoint_id.device_id.device_uuid.uuid, endpoint_id.endpoint_uuid.uuid])
    topology_uuid = endpoint_id.topology_id.topology_uuid.uuid
    topology_context_uuid = endpoint_id.topology_id.context_id.context_uuid.uuid
    if topology_context_uuid and topology_uuid:
        str_topology_key = key_to_str([topology_context_uuid, topology_uuid])
        str_endpoint_key = key_to_str([str_endpoint_key, str_topology_key], separator=':')
    return str_endpoint_key

def set_path(
        database : Database, connection_uuid : str, raw_endpoint_ids : List[EndPointId], path_name : str = ''
    ) -> List[Union[PathModel, PathHopModel]]:
    """Create or update a connection's path so that its hops match ``raw_endpoint_ids``.

    The path is stored under ``connection_uuid``. When a non-empty ``path_name`` is given, the
    key is instead ``connection_uuid`` and ``path_name`` joined with ':'. One hop is set per
    endpoint id, and its position is the endpoint's index in ``raw_endpoint_ids``. Any hop of
    this path that was stored before but is not set by this call is deleted.

    Args:
        database: ORM database the objects are stored in.
        connection_uuid: UUID of the connection that owns the path.
        raw_endpoint_ids: ordered endpoint ids forming the path.
        path_name: optional suffix that distinguishes several paths of one connection.
            ``None`` is treated like ``''``.

    Returns:
        A list whose first element is the PathModel, followed by the PathHopModel of each
        endpoint in input order.

    Raises:
        Whatever ``get_object`` raises for an unknown endpoint key. It is called without
        ``raise_if_not_found=False``, so presumably it raises; confirm in common.orm.HighLevel.
    """
    str_path_key = connection_uuid if not path_name else key_to_str([connection_uuid, path_name], separator=':')
    db_path, _ = get_or_create_object(database, PathModel, str_path_key)

    # Hops currently stored for this path. Each hop (re)written below is discarded from this
    # set, so whatever remains afterwards is stale and gets deleted.
    stale_path_hop_pks : Set[str] = {reference[0] for reference in db_path.references(PathHopModel)}
    db_objects : List[Union[PathModel, PathHopModel]] = [db_path]

    for position, endpoint_id in enumerate(raw_endpoint_ids):
        db_endpoint : EndPointModel = get_object(database, EndPointModel, _endpoint_key_from_id(endpoint_id))
        # NOTE(review): set_path_hop keys a hop by (path pk, endpoint pk) only. An endpoint
        # listed twice in raw_endpoint_ids therefore rewrites the same hop, and presumably only
        # its last position survives. Confirm whether such input can occur.
        db_path_hop, _ = set_path_hop(database, db_path, position, db_endpoint)
        db_objects.append(db_path_hop)
        stale_path_hop_pks.discard(db_path_hop.instance_key)

    for db_path_hop_pk in stale_path_hop_pks:
        delete_path_hop(database, db_path, db_path_hop_pk)

    return db_objects