# Database.py -- generic persistence helpers built on a SQLAlchemy session factory.
import logging
from typing import List, Optional, Tuple, Type

from sqlalchemy import MetaData
from sqlalchemy.orm import Session, joinedload

from common.orm.backend.Tools import key_to_str
from common.rpc_method_wrapper.ServiceExceptions import NotFoundException
from context.service.database.Base import Base

# Module-level logger, named after this module; used by the dump helpers below.
LOGGER = logging.getLogger(__name__)


class Database(Session):
    """Generic persistence helper built around a SQLAlchemy session factory.

    Every operation opens a short-lived session from the factory supplied at
    construction time and closes it before returning, so no session is kept
    open between calls.

    Models handled by this class are expected to expose ``main_pk_name()``,
    which is used here as the name of the attribute holding the primary key.

    NOTE(review): the class subclasses ``Session`` but all work goes through
    ``self.session``; the base class is kept only so existing callers and
    ``isinstance`` checks keep working.
    """

    def __init__(self, session):
        """Store the session factory.

        Args:
            session: Zero-argument callable whose result is used as a context
                manager yielding a SQLAlchemy session (presumably a
                ``sessionmaker`` -- confirm against the caller).
        """
        super().__init__()
        self.session = session

    def get_all(self, model):
        """Return all stored instances of ``model`` as a new list."""
        with self.session() as session:
            return list(session.query(model).all())

    def create_or_update(self, model):
        """Insert ``model`` or update the stored row with the same primary key.

        Args:
            model: ORM instance to persist.

        Returns:
            ``(model, found)`` where ``model`` is the instance that was passed
            in and ``found`` is ``True`` when a row with the same primary key
            already existed (update) and ``False`` when it did not (insert).
            Note this flag has the opposite sense of the one returned by
            ``get_or_create``.
        """
        model_class = type(model)
        pk_name = model.main_pk_name()
        pk_value = getattr(model, pk_name)

        with self.session() as session:
            # Look the row up before merging so we can report insert vs. update.
            existing = session.query(model_class).filter_by(**{pk_name: pk_value}).one_or_none()
            found = existing is not None

            session.merge(model)
            session.commit()

        return model, found

    def create(self, model):
        """Add ``model`` to the database, commit, and return the same instance."""
        with self.session() as session:
            session.add(model)
            session.commit()
        return model

    def remove(self, model, filter_d):
        """Delete every row of ``type(model)`` matching ``filter_d``.

        Args:
            model: ORM instance; only its class is used to select the table.
            filter_d: Keyword filters passed to ``Query.filter_by``.
        """
        model_class = type(model)
        with self.session() as session:
            session.query(model_class).filter_by(**filter_d).delete()
            session.commit()

    def clear(self):
        """Drop and re-create every table registered on ``Base.metadata``."""
        with self.session() as session:
            engine = session.get_bind()
        Base.metadata.drop_all(engine)
        Base.metadata.create_all(engine)

    def _dump_tables(self) -> List[Tuple[str, List[dict]]]:
        """Return ``(table_name, rows)`` pairs for every reflected table.

        Tables come in ``MetaData.sorted_tables`` order and each row is a plain
        ``dict`` keyed by column name.
        """
        with self.session() as session:
            meta = MetaData()
            meta.reflect(session.get_bind())
            # Session.execute + Row._mapping work on SQLAlchemy 1.4 and 2.0,
            # unlike Engine.execute / dict(row), which 2.0 removed.
            return [
                (table.name, [dict(row._mapping) for row in session.execute(table.select())])
                for table in meta.sorted_tables
            ]

    def dump_by_table(self):
        """Return (and log) the database content as ``{table_name: [row_dict, ...]}``."""
        result = dict(self._dump_tables())
        LOGGER.info(result)
        return result

    def dump_all(self):
        """Return (and log) the database content as a flat list of ``(table_name, row_dict)``."""
        result = [
            (table_name, row)
            for table_name, rows in self._dump_tables()
            for row in rows
        ]
        LOGGER.info(result)

        return result

    def get_object(self, model_class: Type[Base], main_key: str, raise_if_not_found=False) -> Optional[Base]:
        """Fetch the instance of ``model_class`` whose primary key is ``main_key``.

        Args:
            model_class: ORM model class to query.
            main_key: Value of the primary key to look up.
            raise_if_not_found: When true, a missing row raises instead of
                returning ``None``.

        Returns:
            The matching instance, or ``None`` when there is none and
            ``raise_if_not_found`` is false.

        Raises:
            NotFoundException: No row matches and ``raise_if_not_found`` is true.
        """
        filt = {model_class.main_pk_name(): main_key}
        with self.session() as session:
            obj = session.query(model_class).filter_by(**filt).one_or_none()

        # Compare against None explicitly: a model may define its own truthiness.
        if obj is None and raise_if_not_found:
            raise NotFoundException(model_class.__name__.replace('Model', ''), main_key)

        return obj

    def get_or_create(self, model_class: Type[Base], key_parts: List[str], filt=None) -> Tuple[Base, bool]:
        """Return the instance identified by ``key_parts``, creating it if missing.

        Args:
            model_class: ORM model class to query or instantiate.
            key_parts: Key components; converted with ``key_to_str`` into the
                primary-key value.
            filt: Optional ``filter_by`` keyword filters used for the lookup
                instead of the primary key.

        Returns:
            ``(instance, created)`` where ``created`` is ``True`` only when a
            new row was inserted.
        """
        str_key = key_to_str(key_parts)
        if not filt:
            # Look up by the same string key that is stored on creation;
            # filtering on the raw key_parts list could never match a stored row.
            filt = {model_class.main_pk_name(): str_key}

        with self.session() as session:
            existing = session.query(model_class).filter_by(**filt).one_or_none()
            if existing is not None:
                return existing, False

            obj = model_class()
            setattr(obj, model_class.main_pk_name(), str_key)
            session.add(obj)
            session.commit()
            return obj, True