# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterator, NamedTuple, Tuple
import grpc, json, logging
# NOTE(review): notify_event is assumed to live next to MockMessageBroker (the Context mock this
# file was derived from called it with the same signature) - confirm the import location.
from common.tests.MockMessageBroker import MockMessageBroker, notify_event
from common.tools.grpc.Tools import grpc_message_to_json, grpc_message_to_json_string
from context.proto.context_pb2 import Empty, EventTypeEnum, TeraFlowController
from dlt.connector.proto.dlt_pb2 import (
    DltPeerStatus, DltPeerStatusList, DltRecord, DltRecordEvent, DltRecordId, DltRecordOperationEnum,
    DltRecordStatus, DltRecordSubscription, DltRecordTypeEnum)
from dlt.connector.proto.dlt_pb2_grpc import DltServiceServicer

LOGGER = logging.getLogger(__name__)

# Message-broker topic on which record changes are published and from which SubscribeToDlt consumes.
TOPIC_DLT = 'dlt'

# Records are keyed by the JSON serialization of their DltRecordId. GetFromDlt receives only a
# DltRecordId (no operation), so the key must be derivable from the record id alone. Serializing the
# whole id also works regardless of whether its fields are plain strings or nested messages.
DltRecordKey = str                              # grpc_message_to_json_string(record_id)
DltRecordDict = Dict[DltRecordKey, DltRecord]   # dlt_record_key => dlt_record

class MockServicerImpl_Dlt(DltServiceServicer):
    """In-memory mock of the DLT connector gRPC service, for use in tests.

    Records are kept in `self.records`; every change is published on TOPIC_DLT through a
    MockMessageBroker so that open SubscribeToDlt streams can observe it.
    """

    def __init__(self):
        LOGGER.info('[__init__] Creating Servicer...')
        self.records : DltRecordDict = {}
        self.msg_broker = MockMessageBroker()
        LOGGER.info('[__init__] Servicer Created')

    # ----- Helpers ----------------------------------------------------------------------------------------------------

    @staticmethod
    def _record_key(record_id : DltRecordId) -> DltRecordKey:
        """Return the key under which the record identified by `record_id` is stored."""
        return grpc_message_to_json_string(record_id)

    def _notify(self, event_type : Any, record_id : DltRecordId) -> None:
        """Publish a change event about `record_id` on TOPIC_DLT.

        NOTE(review): assumes notify_event emits {'event': {...}, **fields} as it did for context
        events, and that DltRecordEvent has fields named 'event' and 'record_id' - confirm against dlt.proto.
        """
        notify_event(self.msg_broker, TOPIC_DLT, event_type, {'record_id': grpc_message_to_json(record_id)})

    # ----- DLT RPCs ---------------------------------------------------------------------------------------------------

    def RecordToDlt(self, request : DltRecord, context : grpc.ServicerContext) -> DltRecordStatus:
        """Store, replace or delete a record and publish the corresponding event.

        Aborts the RPC with NOT_FOUND when deleting a record that is not stored.
        """
        LOGGER.info('[RecordToDlt] request={:s}'.format(grpc_message_to_json_string(request)))
        record_id = request.record_id
        record_key = self._record_key(record_id)

        # NOTE(review): enum value name assumed from dlt.proto - confirm.
        if request.operation == DltRecordOperationEnum.DLTRECORDOPERATION_DELETE:
            if self.records.pop(record_key, None) is None:
                context.abort(grpc.StatusCode.NOT_FOUND, 'DltRecord({:s}) not found'.format(record_key))
            event_type = EventTypeEnum.EVENTTYPE_REMOVE
        else:
            exists = record_key in self.records
            stored = DltRecord()
            stored.CopyFrom(request)    # keep our own copy; never alias the caller's message
            self.records[record_key] = stored
            event_type = EventTypeEnum.EVENTTYPE_UPDATE if exists else EventTypeEnum.EVENTTYPE_CREATE

        self._notify(event_type, record_id)
        # TODO: populate the status fields (e.g. record id / success flag) once they are settled in dlt.proto.
        return DltRecordStatus()

    def GetFromDlt(self, request : DltRecordId, context : grpc.ServicerContext) -> DltRecord:
        """Return the stored record identified by `request`; abort with NOT_FOUND if absent."""
        LOGGER.info('[GetFromDlt] request={:s}'.format(grpc_message_to_json_string(request)))
        record_key = self._record_key(request)
        record = self.records.get(record_key)
        if record is None:
            context.abort(grpc.StatusCode.NOT_FOUND, 'DltRecord({:s}) not found'.format(record_key))
        return record

    def SubscribeToDlt(
        self, request : DltRecordSubscription, context : grpc.ServicerContext
    ) -> Iterator[DltRecordEvent]:
        """Stream a DltRecordEvent for every record change published on TOPIC_DLT."""
        LOGGER.info('[SubscribeToDlt] request={:s}'.format(grpc_message_to_json_string(request)))
        # TODO: apply the filtering criteria carried in `request`; currently every event is forwarded.
        for message in self.msg_broker.consume({TOPIC_DLT}):
            yield DltRecordEvent(**json.loads(message.content))

    def GetDltStatus(self, request : TeraFlowController, context : grpc.ServicerContext) -> DltPeerStatus:
        """Return the status of a DLT peer. The mock tracks no peers, so an empty message is returned."""
        LOGGER.info('[GetDltStatus] request={:s}'.format(grpc_message_to_json_string(request)))
        return DltPeerStatus()

    def GetDltPeers(self, request : Empty, context : grpc.ServicerContext) -> DltPeerStatusList:
        """Return the known DLT peers. The mock tracks no peers, so the list is empty."""
        LOGGER.info('[GetDltPeers] request={:s}'.format(grpc_message_to_json_string(request)))
        return DltPeerStatusList()