# NatsBackendThread.py 2.61 KB
# Newer Older
# Lluis Gifre Renom's avatar
# Lluis Gifre Renom committed
# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio, nats, nats.errors, queue, threading
from typing import Callable, Optional, Set
from common.message_broker.Message import Message

class NatsBackendThread(threading.Thread):
    """Thread running a private asyncio event loop that performs all the NATS I/O.

    Other threads interact with it only through `publish()`, `subscribe()` and `terminate()`.
    asyncio objects (Event, Queue, loop.create_task) are not thread-safe, so those methods never
    touch them directly. Instead they hand a callback to the loop with
    `loop.call_soon_threadsafe()`, which also wakes the loop if it is blocked waiting for I/O.
    The callbacks (`_on_terminate`, `_enqueue_message`, `_start_subscriber`) therefore always
    execute on the backend thread, while `run()` is driving the loop.
    """

    def __init__(self, nats_uri : str) -> None:
        self._nats_uri = nats_uri
        # Dedicated loop, driven only by run(). asyncio.get_event_loop() would hand us the
        # constructing thread's loop (sharing it across threads) and is deprecated when that
        # thread has no current loop.
        self._event_loop = asyncio.new_event_loop()
        # Created in run(), once self._event_loop is the backend thread's current loop. On
        # Python 3.9, asyncio.Event/Queue bind to the current loop when they are constructed.
        self._terminate : Optional[asyncio.Event] = None
        # Outgoing messages. None is the end-of-stream sentinel pushed by _on_terminate().
        self._publish_queue : Optional['asyncio.Queue[Optional[Message]]'] = None
        # Subscriber tasks still running; only accessed from the backend thread.
        self._subscriber_tasks : Set[asyncio.Task] = set()
        super().__init__()

    def _call_in_loop(self, callback : Callable[..., None], *args) -> None:
        """Schedule `callback(*args)` on the backend loop; safe to call from any thread.

        Requests issued after run() has finished (loop closed) are dropped: no thread is left
        to serve them.
        """
        try:
            self._event_loop.call_soon_threadsafe(callback, *args)
        except RuntimeError:
            # call_soon_threadsafe() raises RuntimeError('Event loop is closed') once run() ends.
            pass

    def terminate(self) -> None:
        """Request shutdown of the backend thread; safe to call from any thread, more than once."""
        self._call_in_loop(self._on_terminate)

    def _on_terminate(self) -> None:
        """Backend-thread side of terminate()."""
        if self._terminate.is_set():
            return
        self._terminate.set()
        # Wake the publisher, which may be blocked on an empty queue. It stops at this sentinel
        # after sending every message enqueued before it.
        self._publish_queue.put_nowait(None)

    async def _run_publisher(self) -> None:
        """Send queued messages over a dedicated NATS connection until the end-of-stream sentinel."""
        client = await nats.connect(servers=[self._nats_uri])
        try:
            while True:
                message = await self._publish_queue.get()
                if message is None:
                    break  # sentinel enqueued by _on_terminate()
                await client.publish(message.topic, message.content.encode('UTF-8'))
            await client.drain()
        finally:
            # Releases the connection if publishing/draining failed; after a completed drain()
            # the client is already closed and close() returns immediately.
            await client.close()

    def publish(self, topic_name : str, message_content : str) -> None:
        """Queue `message_content` for publication on `topic_name`; safe to call from any thread."""
        self._call_in_loop(self._enqueue_message, Message(topic_name, message_content))

    def _enqueue_message(self, message : Message) -> None:
        """Backend-thread side of publish()."""
        if self._terminate.is_set():
            # The publisher stops at the sentinel; anything queued after it would never be sent.
            return
        self._publish_queue.put_nowait(message)

    async def _run_subscriber(
        self, topic_name : str, timeout : float, out_queue : queue.Queue[Message], unsubscribe : threading.Event
    ) -> None:
        """Forward messages of `topic_name` into `out_queue` until terminate() or `unsubscribe`.

        Both stop conditions are re-checked after every message and, at the latest, every
        `timeout` seconds (the wait in next_msg() is bounded by it). Uses its own NATS connection.
        """
        client = await nats.connect(servers=[self._nats_uri])
        try:
            subscription = await client.subscribe(topic_name)
            while not self._terminate.is_set() and not unsubscribe.is_set():
                try:
                    message = await subscription.next_msg(timeout)
                except nats.errors.TimeoutError:
                    continue
                # NOTE(review): if `out_queue` has a maxsize and is full, this blocking put()
                # stalls the whole backend loop (publisher and other subscribers) — confirm the
                # callers use unbounded queues.
                out_queue.put(Message(message.subject, message.data.decode('UTF-8')))
            await subscription.unsubscribe()
            await client.drain()
        finally:
            # Releases the connection on errors; no-op once drain() has closed the client.
            await client.close()

    def subscribe(
        self, topic_name : str, timeout : float, out_queue : queue.Queue[Message], unsubscribe : threading.Event
    ) -> None:
        """Start forwarding `topic_name` messages into `out_queue`; safe to call from any thread.

        Returns immediately. The subscription is started asynchronously on the backend thread
        and stops when `unsubscribe` is set or the backend terminates.
        """
        self._call_in_loop(self._start_subscriber, topic_name, timeout, out_queue, unsubscribe)

    def _start_subscriber(
        self, topic_name : str, timeout : float, out_queue : queue.Queue[Message], unsubscribe : threading.Event
    ) -> None:
        """Backend-thread side of subscribe(): spawn the subscriber task and track it."""
        if self._terminate.is_set():
            return  # shutting down; run() is no longer waiting for new workers
        task = self._event_loop.create_task(self._run_subscriber(topic_name, timeout, out_queue, unsubscribe))
        self._subscriber_tasks.add(task)
        task.add_done_callback(self._subscriber_tasks.discard)

    async def _run_until_terminated(self) -> None:
        """Run the publisher, wait for terminate(), then wait for every worker to finish."""
        publisher = self._event_loop.create_task(self._run_publisher())
        await self._terminate.wait()
        # Let each worker drain/close its NATS connection: the publisher flushes the messages
        # queued before terminate(), and subscribers notice the flag within their `timeout`.
        # asyncio.wait() does not retrieve task exceptions, so failures are still reported by asyncio.
        await asyncio.wait({publisher, *self._subscriber_tasks})

    def run(self) -> None:
        """Thread body: serve requests until terminate(), then release all loop resources."""
        asyncio.set_event_loop(self._event_loop)
        # Callbacks queued by other threads only execute once the loop below runs, so they
        # always find these primitives already created.
        self._terminate = asyncio.Event()
        self._publish_queue = asyncio.Queue()
        try:
            self._event_loop.run_until_complete(self._run_until_terminated())
        finally:
            self._event_loop.run_until_complete(self._event_loop.shutdown_asyncgens())
            self._event_loop.run_until_complete(self._event_loop.shutdown_default_executor())
            self._event_loop.close()