diff --git a/src/interdomain/service/topology_abstractor/DltRecorder.py b/src/interdomain/service/topology_abstractor/DltRecorder.py index c5660e43d55bf2fe52a859c75acd5b09d6bdaa0f..418a53612852fd7d2a33b2a09b95e059e0bbcccf 100644 --- a/src/interdomain/service/topology_abstractor/DltRecorder.py +++ b/src/interdomain/service/topology_abstractor/DltRecorder.py @@ -15,6 +15,7 @@ import logging import threading import asyncio +import time from typing import Dict, Optional from common.Constants import DEFAULT_CONTEXT_NAME, DEFAULT_TOPOLOGY_NAME, INTERDOMAIN_TOPOLOGY_NAME, ServiceNameEnum @@ -47,6 +48,12 @@ class DLTRecorder(threading.Thread): self.context_event_collector = EventsCollector(self.context_client) self.topology_cache: Dict[str, TopologyId] = {} + # Queues for each event type + self.create_event_queue = asyncio.Queue() + self.update_event_queue = asyncio.Queue() + self.remove_event_queue = asyncio.Queue() + + def stop(self): self.terminate.set() @@ -61,27 +68,29 @@ class DLTRecorder(threading.Thread): tasks = [] batch_timeout = 1 # Time in seconds to wait before processing whatever tasks are available + last_task_time = time.time() while not self.terminate.is_set(): event = self.context_event_collector.get_event(timeout=0.1) - if event is None: - continue - LOGGER.info('Processing Event({:s})...'.format(grpc_message_to_json_string(event))) - task = asyncio.create_task(self.update_record(event)) - tasks.append(task) - LOGGER.debug('Task for event scheduled.') - # Limit the number of concurrent tasks - # If we have enough tasks or it's time to process them - if len(tasks) >= 10 or (tasks and len(tasks) > 0 and await asyncio.sleep(batch_timeout)): + if event: + LOGGER.info('Processing Event({:s})...'.format(grpc_message_to_json_string(event))) + task = asyncio.create_task(self.update_record(event)) + tasks.append(task) + LOGGER.debug('Task for event scheduled.') + + # Update the last task time since we've added a new task + last_task_time = time.time() + + # Check if it's time to process the tasks or if we have enough tasks + if tasks and (len(tasks) >= 10 or (time.time() - last_task_time >= batch_timeout)): try: await asyncio.gather(*tasks) except Exception as e: LOGGER.error(f"Error while processing tasks: {e}") finally: tasks = [] # Clear the list after processing - await asyncio.gather(*tasks) - tasks = [] # Clear the list after processing - # Process any remaining tasks when stopping + + # Process any remaining tasks when stopping if tasks: try: await asyncio.gather(*tasks) @@ -91,10 +100,6 @@ class DLTRecorder(threading.Thread): self.context_event_collector.stop() self.context_client.close() - #def create_topologies(self): - #topology_uuids = [DEFAULT_TOPOLOGY_NAME, INTERDOMAIN_TOPOLOGY_NAME] - #create_missing_topologies(self.context_client, ADMIN_CONTEXT_ID, topology_uuids) - async def update_record(self, event: EventTypes) -> None: dlt_record_sender = DltRecordSender(self.context_client) await dlt_record_sender.initialize() # Ensure DltRecordSender is initialized asynchronously