diff --git a/src/interdomain/service/topology_abstractor/DltRecorder.py b/src/interdomain/service/topology_abstractor/DltRecorder.py index 1c5661b60e20c9650a62c79d356db052e2c187d1..6f4ce5e590c0ea4f8aa3428ce2c99f7df8d69b07 100644 --- a/src/interdomain/service/topology_abstractor/DltRecorder.py +++ b/src/interdomain/service/topology_abstractor/DltRecorder.py @@ -4,7 +4,7 @@ from typing import Dict, Optional from common.Constants import DEFAULT_CONTEXT_NAME, DEFAULT_TOPOLOGY_NAME, INTERDOMAIN_TOPOLOGY_NAME, ServiceNameEnum from common.Settings import ENVVAR_SUFIX_SERVICE_HOST, ENVVAR_SUFIX_SERVICE_PORT_GRPC, find_environment_variables, get_env_var_name -from common.proto.context_pb2 import ContextEvent, ContextId, Device, DeviceEvent, DeviceId, EndPointId, Link, LinkId, LinkEvent, TopologyId, TopologyEvent +from common.proto.context_pb2 import ContextEvent, ContextId, Device, DeviceEvent, DeviceId, EndPointId, Link, LinkEvent, TopologyId, TopologyEvent from common.tools.context_queries.Context import create_context from common.tools.context_queries.Device import get_uuids_of_devices_in_topology from common.tools.context_queries.Topology import create_missing_topologies @@ -30,7 +30,7 @@ class DLTRecorder(threading.Thread): self.terminate = threading.Event() self.context_client = ContextClient() self.context_event_collector = EventsCollector(self.context_client) - self.topology_cache = {} + self.topology_cache: Dict[str, TopologyId] = {} def stop(self): self.terminate.set() @@ -56,15 +56,19 @@ class DLTRecorder(threading.Thread): create_missing_topologies(self.context_client, ADMIN_CONTEXT_ID, topology_uuids) def get_dlt_connector_client(self) -> Optional[DltConnectorClient]: - env_vars = find_environment_variables([ - get_env_var_name(ServiceNameEnum.DLT, ENVVAR_SUFIX_SERVICE_HOST), - get_env_var_name(ServiceNameEnum.DLT, ENVVAR_SUFIX_SERVICE_PORT_GRPC), - ]) - if len(env_vars) == 2: - dlt_connector_client = DltConnectorClient() - dlt_connector_client.connect() - return dlt_connector_client - return None + # Always enable DLT for testing + dlt_connector_client = DltConnectorClient() + dlt_connector_client.connect() + return dlt_connector_client + # env_vars = find_environment_variables([ + # get_env_var_name(ServiceNameEnum.DLT, ENVVAR_SUFIX_SERVICE_HOST), + # get_env_var_name(ServiceNameEnum.DLT, ENVVAR_SUFIX_SERVICE_PORT_GRPC), + # ]) + # if len(env_vars) == 2: + # dlt_connector_client = DltConnectorClient() + # dlt_connector_client.connect() + # return dlt_connector_client + # return None def update_record(self, event: EventTypes) -> None: dlt_connector_client = self.get_dlt_connector_client() @@ -106,7 +110,7 @@ class DLTRecorder(threading.Thread): topology_details = self.context_client.GetTopologyDetails(topology_id) topology_name = topology_details.name - self.topology_cache[topology_uuid] = topology_details + self.topology_cache[topology_uuid] = topology_id if ((context_uuid == DEFAULT_CONTEXT_NAME) or (context_name == DEFAULT_CONTEXT_NAME)) and \ (topology_uuid not in topology_uuids) and (topology_name not in topology_uuids): @@ -122,14 +126,16 @@ class DLTRecorder(threading.Thread): LOGGER.warning(MSG.format(*args)) def find_topology_for_device(self, device_id: DeviceId) -> Optional[TopologyId]: - for topology_id, details in self.topology_cache.items(): + for topology_uuid, topology_id in self.topology_cache.items(): + details = self.context_client.GetTopologyDetails(topology_id) for device in details.devices: if device.device_id == device_id: return topology_id return None def find_topology_for_link(self, link_id: LinkId) -> Optional[TopologyId]: - for topology_id, details in self.topology_cache.items(): + for topology_uuid, topology_id in self.topology_cache.items(): + details = self.context_client.GetTopologyDetails(topology_id) for link in details.links: if link.link_id == link_id: return topology_id