diff --git a/src/context/service/database/Device.py b/src/context/service/database/Device.py index 579e7631ec4b4ab4fa78887b196dd7684fe6f93b..d816d3c5f06410bdaef9b4d698666c06bc5aa0ac 100644 --- a/src/context/service/database/Device.py +++ b/src/context/service/database/Device.py @@ -153,7 +153,7 @@ def device_set(db_engine : Engine, messagebroker : MessageBroker, request : Devi if controller_uuid is not None: device_data[0]['controller_uuid'] = controller_uuid - def callback(session : Session) -> bool: + def callback(session : Session) -> Tuple[bool, List[Dict]]: stmt = insert(DeviceModel).values(device_data) stmt = stmt.on_conflict_do_update( index_elements=[DeviceModel.device_uuid], @@ -188,29 +188,77 @@ def device_set(db_engine : Engine, messagebroker : MessageBroker, request : Devi if not updated or len(related_topologies) > 1: # Only update topology-device relations when device is created (not updated) or when endpoints are # modified (len(related_topologies) > 1). - session.execute(insert(TopologyDeviceModel).values(related_topologies).on_conflict_do_nothing( + stmt = insert(TopologyDeviceModel).values(related_topologies) + stmt = stmt.on_conflict_do_nothing( index_elements=[TopologyDeviceModel.topology_uuid, TopologyDeviceModel.device_uuid] - )) + ) + stmt = stmt.returning(TopologyDeviceModel.topology_uuid) + topology_uuids = session.execute(stmt).fetchall() + + LOGGER.warning('topology_uuids={:s}'.format(str(topology_uuids))) + if len(topology_uuids) > 0: + query = session.query(TopologyModel) + query = query.filter(TopologyModel.topology_uuid.in_(topology_uuids)) + device_topologies : List[TopologyModel] = query.all() + device_topology_ids = [obj.dump_id() for obj in device_topologies] + else: + device_topology_ids = [] changed_config_rules = upsert_config_rules(session, config_rules, device_uuid=device_uuid) - return updated or updated_endpoints or changed_config_rules + return updated or updated_endpoints or changed_config_rules, device_topology_ids - updated = run_transaction(sessionmaker(bind=db_engine), callback) + updated, device_topology_ids = run_transaction(sessionmaker(bind=db_engine), callback) device_id = json_device_id(device_uuid) event_type = EventTypeEnum.EVENTTYPE_UPDATE if updated else EventTypeEnum.EVENTTYPE_CREATE notify_event_device(messagebroker, event_type, device_id) + + context_ids : Dict[str, Dict] = dict() + topology_ids : Dict[str, Dict] = dict() + for topology_id in device_topology_ids: + topology_uuid = topology_id['topology_uuid']['uuid'] + topology_ids[topology_uuid] = topology_id + context_id = topology_id['context_id'] + context_uuid = context_id['context_uuid']['uuid'] + context_ids[context_uuid] = context_id + + for topology_id in topology_ids.values(): + notify_event_topology(messagebroker, EventTypeEnum.EVENTTYPE_UPDATE, topology_id) + + for context_id in context_ids.values(): + notify_event_context(messagebroker, EventTypeEnum.EVENTTYPE_UPDATE, context_id) + return DeviceId(**device_id) def device_delete(db_engine : Engine, messagebroker : MessageBroker, request : DeviceId) -> Empty: device_uuid = device_get_uuid(request, allow_random=False) - def callback(session : Session) -> bool: + def callback(session : Session) -> Tuple[bool, List[Dict]]: + query = session.query(TopologyDeviceModel) + query = query.filter_by(device_uuid=device_uuid) + topology_device_list : List[TopologyDeviceModel] = query.all() + topology_ids = [obj.topology.dump_id() for obj in topology_device_list] num_deleted = session.query(DeviceModel).filter_by(device_uuid=device_uuid).delete() - return num_deleted > 0 - deleted = run_transaction(sessionmaker(bind=db_engine), callback) + return num_deleted > 0, topology_ids + deleted, updated_topology_ids = run_transaction(sessionmaker(bind=db_engine), callback) device_id = json_device_id(device_uuid) if deleted: notify_event_device(messagebroker, EventTypeEnum.EVENTTYPE_REMOVE, device_id) + + context_ids : Dict[str, Dict] = dict() + topology_ids : Dict[str, Dict] = dict() + for topology_id in updated_topology_ids: + topology_uuid = topology_id['topology_uuid']['uuid'] + topology_ids[topology_uuid] = topology_id + context_id = topology_id['context_id'] + context_uuid = context_id['context_uuid']['uuid'] + context_ids[context_uuid] = context_id + + for topology_id in topology_ids.values(): + notify_event_topology(messagebroker, EventTypeEnum.EVENTTYPE_UPDATE, topology_id) + + for context_id in context_ids.values(): + notify_event_context(messagebroker, EventTypeEnum.EVENTTYPE_UPDATE, context_id) + return Empty() def device_select(db_engine : Engine, request : DeviceFilter) -> DeviceList: