Skip to content
Snippets Groups Projects
Database.py 3.38 KiB
Newer Older
from typing import Tuple, List

from sqlalchemy import MetaData
from sqlalchemy.orm import Session
from context.service.database.Base import Base
import logging
from common.orm.backend.Tools import key_to_str

from common.rpc_method_wrapper.ServiceExceptions import NotFoundException

LOGGER = logging.getLogger(__name__)


class Database(Session):
    """Small persistence helper built on a SQLAlchemy session factory.

    Every public method opens a fresh session via ``self.session()`` (used as a
    context manager), performs one unit of work and closes the session again.
    """

    def __init__(self, session):
        """
        :param session: Callable returning a new session usable in a ``with``
            statement (presumably a ``sessionmaker`` — confirm with callers).
        """
        super().__init__()
        self.session = session

    def _get_engine(self):
        """Return the bind (engine) of a short-lived session from the factory."""
        with self.session() as session:
            return session.get_bind()

    def get_all(self, model):
        """Return a list with every row of ``model`` as ORM instances."""
        with self.session() as session:
            return session.query(model).all()

    def create_or_update(self, model):
        """Insert or update ``model`` by merging it on its main primary key.

        :param model: ORM instance exposing ``main_pk_name()``.
        :return: ``(model, found)`` where ``model`` is the instance passed in and
            ``found`` is True when a row with the same main primary key already
            existed before the merge.
        """
        pk_name = model.main_pk_name()
        filt = {pk_name: getattr(model, pk_name)}
        with self.session() as session:
            # Compare against None explicitly: truthiness of an ORM instance is
            # not a reliable "row exists" test.
            found = session.query(type(model)).filter_by(**filt).one_or_none() is not None
            session.merge(model)
            session.commit()
        return model, found

    def create(self, model):
        """Add ``model`` as a new row, commit, and return the same instance."""
        with self.session() as session:
            session.add(model)
            session.commit()
        return model

    def remove(self, model, filter_d):
        """Delete all rows of ``type(model)`` matching the ``filter_d`` keyword filter."""
        model_t = type(model)
        with self.session() as session:
            session.query(model_t).filter_by(**filter_d).delete()
            session.commit()

    def clear(self):
        """Drop and re-create every table registered on ``Base.metadata``."""
        engine = self._get_engine()
        Base.metadata.drop_all(engine)
        Base.metadata.create_all(engine)

    def dump_by_table(self):
        """Return ``{table_name: [row_as_dict, ...]}`` for every reflected table.

        The result is also logged at INFO level.
        """
        engine = self._get_engine()
        meta = MetaData()
        meta.reflect(engine)
        result = {}

        # engine.execute()/dict(row) are deprecated in SQLAlchemy 1.4 and removed
        # in 2.0; use an explicit connection and Row._mapping instead.
        # NOTE(review): assumes get_bind() returns an Engine, not a Connection.
        with engine.connect() as connection:
            for table in meta.sorted_tables:
                rows = connection.execute(table.select())
                result[table.name] = [dict(row._mapping) for row in rows]
        LOGGER.info(result)
        return result

    def dump_all(self):
        """Return ``[(table_name, row_as_dict), ...]`` for all rows of every reflected table.

        Tables are visited in dependency order (``MetaData.sorted_tables``).
        The result is also logged at INFO level.
        """
        engine = self._get_engine()
        meta = MetaData()
        meta.reflect(engine)
        result = []

        # See dump_by_table: avoid the deprecated engine.execute()/dict(row) API.
        with engine.connect() as connection:
            for table in meta.sorted_tables:
                for row in connection.execute(table.select()):
                    result.append((table.name, dict(row._mapping)))
        LOGGER.info(result)

        return result

    def get_object(self, model_class: Base, main_key: str, raise_if_not_found=False):
        """Look up one ``model_class`` row by its main primary key.

        :param model_class: ORM model class exposing ``main_pk_name()``.
        :param main_key: value of the main primary key to search for.
        :param raise_if_not_found: when True, raise instead of returning None.
        :return: the matching instance, or None when absent and not raising.
        :raises NotFoundException: if no row matches and ``raise_if_not_found`` is set.
        """
        filt = {model_class.main_pk_name(): main_key}
        with self.session() as session:
            get = session.query(model_class).filter_by(**filt).one_or_none()

            if not get:
                if raise_if_not_found:
                    raise NotFoundException(model_class.__name__.replace('Model', ''), main_key)

            return get

    def get_or_create(self, model_class: Base, key_parts: List[str]
                      ) -> Tuple[Base, bool]:
        """Return the ``model_class`` row keyed by ``key_parts``, creating it if absent.

        :param model_class: ORM model class exposing ``main_pk_name()``.
        :param key_parts: key components, serialized with ``key_to_str``.
        :return: ``(instance, created)`` where ``created`` is True for a new row.
        """
        str_key = key_to_str(key_parts)
        # Look up by the serialized key: that is the value stored in the primary
        # key column when the object is created below. Filtering on the raw
        # ``key_parts`` list never matched an existing row.
        filt = {model_class.main_pk_name(): str_key}
        with self.session() as session:
            get = session.query(model_class).filter_by(**filt).one_or_none()
            if get:
                return get, False

            obj = model_class()
            setattr(obj, model_class.main_pk_name(), str_key)
            LOGGER.info(obj.dump())
            session.add(obj)
            session.commit()
            return obj, True