import logging
from typing import Tuple, List

from sqlalchemy import MetaData
from sqlalchemy.orm import Session

from context.service.database.Base import Base
from common.orm.backend.Tools import key_to_str
from common.rpc_method_wrapper.ServiceExceptions import NotFoundException

LOGGER = logging.getLogger(__name__)


def _row_to_dict(row) -> dict:
    """Convert a Core result row into a plain ``dict`` (column name -> value).

    SQLAlchemy 1.4+ exposes the mapping view of a row as ``row._mapping``;
    older versions let ``dict(row)`` consume the row directly. In 2.0 rows
    are tuple-like, so ``dict(row)`` alone would fail there.
    """
    return dict(getattr(row, '_mapping', row))


class Database(Session):
    """Thin CRUD helper around a session factory.

    Every method opens its own short-lived session via ``self.session()``
    (used as a context manager), so no session is shared between calls.
    """

    def __init__(self, session):
        """Store the session factory used by every operation.

        :param session: callable returning a context-manager session
            (presumably a ``sessionmaker`` — confirm against the caller).
        """
        super().__init__()
        self.session = session

    def get_all(self, model):
        """Return every row of ``model`` as a list of ORM instances."""
        with self.session() as session:
            return session.query(model).all()

    def create_or_update(self, model):
        """Insert ``model`` or update the row sharing its primary key.

        :param model: ORM instance whose class implements ``main_pk_name()``.
        :returns: ``(model, found)``. ``model`` is the instance that was passed
            in, not the merged copy. ``found`` is True when a row with the same
            primary key existed before the merge.
        """
        with self.session() as session:
            pk_name = model.main_pk_name()
            filt = {pk_name: getattr(model, pk_name)}
            existing = session.query(type(model)).filter_by(**filt).one_or_none()
            found = existing is not None
            session.merge(model)
            session.commit()
        return model, found

    def create(self, model):
        """Add ``model`` as a new row, commit, and return the same instance."""
        with self.session() as session:
            session.add(model)
            session.commit()
            return model

    def remove(self, model, filter_d):
        """Delete all rows of ``type(model)`` matching the ``filter_d`` keywords."""
        model_t = type(model)
        with self.session() as session:
            session.query(model_t).filter_by(**filter_d).delete()
            session.commit()

    def clear(self):
        """Drop and recreate every table registered on ``Base.metadata``."""
        with self.session() as session:
            engine = session.get_bind()
            Base.metadata.drop_all(engine)
            Base.metadata.create_all(engine)

    def _dump_tables(self) -> List[Tuple[str, List[dict]]]:
        """Reflect the bound database and return ``(table_name, rows)`` per table.

        Tables are returned in ``MetaData.sorted_tables`` order. Each row is a
        plain dict built by ``_row_to_dict``.
        """
        with self.session() as session:
            engine = session.get_bind()
            meta = MetaData()
            meta.reflect(engine)
            tables = []
            # Explicit connection instead of the deprecated connectionless
            # ``engine.execute`` (removed in SQLAlchemy 2.0).
            with engine.connect() as connection:
                for table in meta.sorted_tables:
                    rows = [_row_to_dict(row)
                            for row in connection.execute(table.select())]
                    tables.append((table.name, rows))
            return tables

    def dump_by_table(self):
        """Return ``{table_name: [row_dict, ...]}`` for every reflected table.

        The whole result is also logged at INFO level.
        """
        result = dict(self._dump_tables())
        LOGGER.info(result)
        return result

    def dump_all(self):
        """Return a flat list of ``(table_name, row_dict)`` pairs for all rows.

        The whole result is also logged at INFO level.
        """
        result = [(name, row)
                  for name, rows in self._dump_tables()
                  for row in rows]
        LOGGER.info(result)
        return result

    def get_object(self, model_class: Base, main_key: str, raise_if_not_found=False):
        """Fetch the ``model_class`` row whose primary key equals ``main_key``.

        :returns: the instance, or ``None`` when absent and
            ``raise_if_not_found`` is False.
        :raises NotFoundException: when absent and ``raise_if_not_found`` is True.
        """
        filt = {model_class.main_pk_name(): main_key}
        with self.session() as session:
            get = session.query(model_class).filter_by(**filt).one_or_none()
            if get is None and raise_if_not_found:
                raise NotFoundException(model_class.__name__.replace('Model', ''), main_key)
            return get

    def get_or_create(self, model_class: Base, key_parts: List[str]
                      ) -> Tuple[Base, bool]:
        """Return the row keyed by ``key_to_str(key_parts)``, creating it if absent.

        :returns: ``(instance, created)``. ``created`` is True only when a new
            row was inserted.
        """
        str_key = key_to_str(key_parts)
        # Look up by the same serialized key that is stored on creation.
        # Filtering on the raw ``key_parts`` list never matched the stored
        # string, so an existing row was missed and a duplicate insert was
        # attempted.
        filt = {model_class.main_pk_name(): str_key}
        with self.session() as session:
            get = session.query(model_class).filter_by(**filt).one_or_none()
            if get is not None:
                return get, False
            obj = model_class()
            setattr(obj, model_class.main_pk_name(), str_key)
            LOGGER.info(obj.dump())
            session.add(obj)
            session.commit()
            # NOTE(review): the session closes on return; if the factory uses
            # expire_on_commit=True, later attribute access on ``obj`` may
            # raise DetachedInstanceError — confirm session configuration.
            return obj, True