Commit 81378162 authored by Pelayo Torres's avatar Pelayo Torres
Browse files

immRep logic

parent b8632d8f
Loading
Loading
Loading
Loading
Loading
+28 −10
Original line number Original line Diff line number Diff line
@@ -4,16 +4,22 @@ import secrets
import rfc3987
import rfc3987
from capif_events.models.event_subscription import EventSubscription  # noqa: E501
from capif_events.models.event_subscription import EventSubscription  # noqa: E501
from flask import current_app, Response
from flask import current_app, Response
from datetime import datetime, timedelta, timezone
from datetime import datetime, timezone
import asyncio


from .auth_manager import AuthManager
from .auth_manager import AuthManager
from .resources import Resource
from .resources import Resource
from .responses import internal_server_error, not_found_error, make_response, bad_request_error
from .responses import internal_server_error, not_found_error, make_response, bad_request_error
from ..util import serialize_clean_camel_case, clean_empty, dict_to_camel_case
from ..util import serialize_clean_camel_case, clean_empty, dict_to_camel_case
from .notifications import Notifications




class EventSubscriptionsOperations(Resource):
class EventSubscriptionsOperations(Resource):


    def __init__(self):
        super().__init__()
        self.notifications = Notifications()

    def __check_subscriber_id(self, subscriber_id):
    def __check_subscriber_id(self, subscriber_id):
        mycol_invoker= self.db.get_col_by_name(self.db.invoker_collection)
        mycol_invoker= self.db.get_col_by_name(self.db.invoker_collection)
        mycol_provider= self.db.get_col_by_name(self.db.provider_collection)
        mycol_provider= self.db.get_col_by_name(self.db.provider_collection)
@@ -58,7 +64,7 @@ class EventSubscriptionsOperations(Resource):
                 return bad_request_error(detail="Bad Param", cause = f"Invalid eventFilter for event {event}", invalid_params=[{"param": "eventFilter", "reason": f"The eventFilter {invalid_filters} for event {event} are not applicable."}])
                 return bad_request_error(detail="Bad Param", cause = f"Invalid eventFilter for event {event}", invalid_params=[{"param": "eventFilter", "reason": f"The eventFilter {invalid_filters} for event {event} are not applicable."}])
        return None
        return None
    
    
    def __check_event_req(self, event_subscription):
    def __check_event_req(self, event_subscription, subscription_id=None):
        current_app.logger.debug("Checking event requirement.")
        current_app.logger.debug("Checking event requirement.")
        expired_at = None
        expired_at = None
        if event_subscription.event_req.mon_dur:
        if event_subscription.event_req.mon_dur:
@@ -79,6 +85,14 @@ class EventSubscriptionsOperations(Resource):
                cause="Periodic notification method selected but repPeriod not provided",
                cause="Periodic notification method selected but repPeriod not provided",
                invalid_params=[{"param": "repPeriod", "reason": "Periodic notification method selected but repPeriod not provided"}]
                invalid_params=[{"param": "repPeriod", "reason": "Periodic notification method selected but repPeriod not provided"}]
            )
            )
        
        if event_subscription.event_req.imm_rep and subscription_id is not None:
            current_app.logger.debug("Sending immediate notification")
            notifications_col = self.db.get_col_by_name(self.db.notifications_col)
            result = notifications_col.find({"subscription_id": subscription_id})
            for notification in result:
                asyncio.run(self.notifications.send(notification["url"], notification["notification"]))

        return expired_at
        return expired_at


    def __init__(self):
    def __init__(self):
@@ -209,6 +223,7 @@ class EventSubscriptionsOperations(Resource):
    def put_event(self, event_subscription, subscriber_id, subscription_id):
    def put_event(self, event_subscription, subscriber_id, subscription_id):
        try:
        try:
            mycol = self.db.get_col_by_name(self.db.event_collection)
            mycol = self.db.get_col_by_name(self.db.event_collection)
            notifications_col = self.db.get_col_by_name(self.db.notifications_col)


            current_app.logger.debug("Updating event subscription")
            current_app.logger.debug("Updating event subscription")
            
            
@@ -224,6 +239,8 @@ class EventSubscriptionsOperations(Resource):
            if  isinstance(result, Response):
            if  isinstance(result, Response):
                return result
                return result
            
            
            current_app.logger.debug(event_subscription)
            expired_at = None
            if EventSubscription.return_supp_feat_dict(event_subscription.supported_features)["EnhancedEventReport"]:
            if EventSubscription.return_supp_feat_dict(event_subscription.supported_features)["EnhancedEventReport"]:
                if event_subscription.event_filters:
                if event_subscription.event_filters:
                    current_app.logger.debug(event_subscription.event_filters)
                    current_app.logger.debug(event_subscription.event_filters)
@@ -231,8 +248,7 @@ class EventSubscriptionsOperations(Resource):
                    if isinstance(result, Response):
                    if isinstance(result, Response):
                        return result
                        return result
                if event_subscription.event_req:
                if event_subscription.event_req:
                    current_app.logger.debug(event_subscription.event_req)
                    expired_at = self.__check_event_req(event_subscription, subscription_id)
                    expired_at = self.__check_event_req(event_subscription)
                    if isinstance(expired_at, Response):
                    if isinstance(expired_at, Response):
                        return result
                        return result


@@ -252,6 +268,7 @@ class EventSubscriptionsOperations(Resource):
            body["created_at"] = eventdescription.get("created_at", datetime.now(timezone.utc))
            body["created_at"] = eventdescription.get("created_at", datetime.now(timezone.utc))
            body["expire_at"] = expired_at if expired_at else eventdescription.get("expire_at", None)
            body["expire_at"] = expired_at if expired_at else eventdescription.get("expire_at", None)


            notifications_col.delete_many({"subscription_id": subscription_id})
            mycol.replace_one(my_query, body)
            mycol.replace_one(my_query, body)
            current_app.logger.debug("Event subscription updated from database")
            current_app.logger.debug("Event subscription updated from database")


@@ -269,6 +286,7 @@ class EventSubscriptionsOperations(Resource):
    def patch_event(self, event_subscription, subscriber_id, subscription_id):
    def patch_event(self, event_subscription, subscriber_id, subscription_id):
        try:
        try:
            mycol = self.db.get_col_by_name(self.db.event_collection)
            mycol = self.db.get_col_by_name(self.db.event_collection)
            notifications_col = self.db.get_col_by_name(self.db.notifications_col)


            current_app.logger.debug("Patching event subscription")
            current_app.logger.debug("Patching event subscription")


@@ -284,7 +302,8 @@ class EventSubscriptionsOperations(Resource):
            if eventdescription is None:
            if eventdescription is None:
                current_app.logger.error("Event subscription not found")
                current_app.logger.error("Event subscription not found")
                return not_found_error(detail="Event subscription not exist", cause="Event API subscription id not found")
                return not_found_error(detail="Event subscription not exist", cause="Event API subscription id not found")

            current_app.logger.debug(event_subscription)
            expired_at = None
            if EventSubscription.return_supp_feat_dict(eventdescription.get("supported_features"))["EnhancedEventReport"]:
            if EventSubscription.return_supp_feat_dict(eventdescription.get("supported_features"))["EnhancedEventReport"]:
                if event_subscription.events and event_subscription.event_filters:
                if event_subscription.events and event_subscription.event_filters:
                    result = self.__check_event_filters(event_subscription.events, clean_empty(event_subscription.to_dict()["event_filters"]))
                    result = self.__check_event_filters(event_subscription.events, clean_empty(event_subscription.to_dict()["event_filters"]))
@@ -292,13 +311,12 @@ class EventSubscriptionsOperations(Resource):
                    result = self.__check_event_filters(event_subscription.events, eventdescription.get("event_filters"))
                    result = self.__check_event_filters(event_subscription.events, eventdescription.get("event_filters"))
                elif event_subscription.events is None and event_subscription.event_filters:
                elif event_subscription.events is None and event_subscription.event_filters:
                    result = self.__check_event_filters(eventdescription.get("events"), clean_empty(event_subscription.to_dict()["event_filters"]))
                    result = self.__check_event_filters(eventdescription.get("events"), clean_empty(event_subscription.to_dict()["event_filters"]))

                if  isinstance(result, Response):
                if  isinstance(result, Response):
                    return result
                    return result
                
                
                if event_subscription.event_req:
                if event_subscription.event_req:
                    current_app.logger.debug(event_subscription.event_req)
                    updated_data = EventSubscription.from_dict(dict_to_camel_case({**eventdescription, **clean_empty(event_subscription.to_dict())}))
                    expired_at = self.__check_event_req(event_subscription)
                    expired_at = self.__check_event_req(updated_data, subscription_id)
                    if isinstance(expired_at, Response):
                    if isinstance(expired_at, Response):
                        return result
                        return result
                    else:
                    else:
@@ -308,8 +326,8 @@ class EventSubscriptionsOperations(Resource):
                    return result
                    return result


            body = clean_empty(event_subscription.to_dict())
            body = clean_empty(event_subscription.to_dict())
            if expired_at:
            body["expire_at"] = expired_at
            body["expire_at"] = expired_at
            notifications_col.delete_many({"subscription_id": subscription_id})
            document = mycol.update_one(my_query, {"$set":body})
            document = mycol.update_one(my_query, {"$set":body})
            document = mycol.find_one(my_query)
            document = mycol.find_one(my_query)
            current_app.logger.debug("Event subscription patched from database")
            current_app.logger.debug("Event subscription patched from database")
+1 −1
Original line number Original line Diff line number Diff line
@@ -4,4 +4,4 @@ redis==4.5.4
aiohttp == 3.10.5
aiohttp == 3.10.5
async-timeout == 4.0.3
async-timeout == 4.0.3
pyyaml == 6.0.2
pyyaml == 6.0.2
python_dateutil >= 2.6.0
python_dateutil == 2.9.0
+0 −9
Original line number Original line Diff line number Diff line
@@ -8,13 +8,6 @@ from config import Config
import aiohttp
import aiohttp
import asyncio
import asyncio


# Celery Configuration
# celery = Celery(
#     "notifications",
#     broker=os.environ.get("CELERY_BROKER_URL", "redis://redis:6379/0"),
#     backend=os.environ.get("CELERY_RESULT_BACKEND", "redis://redis:6379/0")
# )

celery = Celery(
celery = Celery(
    "notifications",
    "notifications",
    broker=f"redis://{os.getenv("REDIS_HOST")}:{os.getenv("REDIS_PORT")}/0",
    broker=f"redis://{os.getenv("REDIS_HOST")}:{os.getenv("REDIS_PORT")}/0",
@@ -108,7 +101,6 @@ async def send(url, data):


@celery.task(name="celery.tasks.check_notifications_collection")
@celery.task(name="celery.tasks.check_notifications_collection")
def my_periodic_task():
def my_periodic_task():
    # print("Checking notifications collection...")
    while True:
    while True:
        try:
        try:
            notification_data = notifications_col.find_one_and_delete(
            notification_data = notifications_col.find_one_and_delete(
@@ -126,4 +118,3 @@ def my_periodic_task():
        except Exception as e:
        except Exception as e:
            print(f"Error sending notification: {e}")
            print(f"Error sending notification: {e}")
    # print("Finished processing notifications.")