# Newer
# Older
import logging
from typing import Tuple #, List

from sqlalchemy import MetaData
from sqlalchemy.orm import Session #, joinedload

from context.service.database._Base import _Base
#from common.orm.backend.Tools import key_to_str
from common.rpc_method_wrapper.ServiceExceptions import NotFoundException
# Module-level logger named after this module (used by Database.dump_by_table).
LOGGER = logging.getLogger(name=__name__)
class Database(Session):
    """Small CRUD / dump facade over a SQLAlchemy session factory.

    Every method opens a fresh session with ``self.session()`` inside a
    ``with`` block and does all of its work there, so no session is held
    between calls.
    """

    def __init__(self, session):
        """Keep a reference to the session factory.

        :param session: zero-argument callable whose return value is used as a
            context manager yielding a session (presumably a SQLAlchemy
            ``sessionmaker`` -- confirm with the caller).
        """
        # NOTE(review): the inherited Session is initialised without a bind and
        # none of the methods below use it; all work goes through self.session().
        super().__init__()
        self.session = session

    def get_all(self, model):
        """Return every row of ``model`` as a list of ORM instances."""
        # NOTE(review): this method's ``def`` line was missing in the source
        # (its body had ended up inside ``__init__``); the name is
        # reconstructed -- confirm it matches existing callers.
        with self.session() as session:
            return session.query(model).all()

    def create_or_update(self, model):
        """Insert ``model``, or update the existing row with the same main key.

        The lookup key is ``{model.main_pk_name(): <that attribute of model>}``.
        If a row matches, every column listed in its ``__table__`` is copied
        from ``model`` onto it. Otherwise ``model`` itself is added.

        :param model: ORM instance whose class provides ``main_pk_name()``.
        :return: ``(model, found)``; ``found`` is True when an existing row was
            updated and False when ``model`` was inserted.
        """
        t_model = type(model)
        pk_name = model.main_pk_name()
        filt = {pk_name: getattr(model, pk_name)}
        with self.session() as session:
            obj = session.query(t_model).filter_by(**filt).one_or_none()
            found = obj is not None
            if found:
                # Update the already-loaded row in place. Also adding ``model``
                # would stage a second object with the same primary key, which
                # conflicts with ``obj`` when the session flushes.
                for key in obj.__table__.columns.keys():
                    setattr(obj, key, getattr(model, key))
            else:
                session.add(model)
            session.commit()
        return model, found

    def create(self, model):
        """Add ``model`` in a new session, commit, and return it."""
        with self.session() as session:
            session.add(model)
            session.commit()
        return model

    def remove(self, model, filter_d):
        """Bulk-delete rows of ``type(model)`` matching ``filter_d`` and commit.

        :param model: instance used only to pick the mapped class.
        :param filter_d: keyword filters passed to ``Query.filter_by``.
        """
        model_t = type(model)
        with self.session() as session:
            session.query(model_t).filter_by(**filter_d).delete()
            session.commit()

    def clear(self):
        """Drop and recreate every table registered on ``_Base.metadata``.

        All data in those tables is destroyed.
        """
        with self.session() as session:
            engine = session.get_bind()
        _Base.metadata.drop_all(engine)
        _Base.metadata.create_all(engine)

    def dump_by_table(self):
        """Return ``{table_name: [row_as_dict, ...]}`` for every reflected table.

        Tables are reflected from the live database and visited in
        ``MetaData.sorted_tables`` order. The whole result is also logged at
        INFO level.
        """
        # NOTE(review): ``Engine.execute`` and ``dict(row)`` are SQLAlchemy 1.x
        # APIs (removed/changed in 2.0) -- revisit if the dependency is upgraded.
        with self.session() as session:
            engine = session.get_bind()
        meta = MetaData()
        meta.reflect(engine)
        result = {}
        for table in meta.sorted_tables:
            result[table.name] = [dict(row) for row in engine.execute(table.select())]
        LOGGER.info(result)
        return result

    def dump_all(self):
        """Return a flat list of ``(table_name, row_as_dict)`` for every row.

        Tables are reflected from the live database and visited in
        ``MetaData.sorted_tables`` order.
        """
        # NOTE(review): same SQLAlchemy 1.x API dependency as dump_by_table.
        with self.session() as session:
            engine = session.get_bind()
        meta = MetaData()
        meta.reflect(engine)
        result = []
        for table in meta.sorted_tables:
            result.extend(
                (table.name, dict(row)) for row in engine.execute(table.select())
            )
        return result

    def get_object(self, model_class: _Base, main_key: str, raise_if_not_found=False):
        """Fetch one row of ``model_class`` by its main primary key.

        :param model_class: mapped class providing ``main_pk_name()``.
        :param main_key: value to match against the main primary key column.
        :param raise_if_not_found: raise instead of returning ``(None, None)``.
        :return: ``(instance, instance.dump())``. The second element is None
            when the instance has no ``dump`` method; the result is
            ``(None, None)`` when no row matches.
        :raises NotFoundException: when no row matches and
            ``raise_if_not_found`` is True.
        """
        filt = {model_class.main_pk_name(): main_key}
        with self.session() as session:
            get = session.query(model_class).filter_by(**filt).one_or_none()
            if get is None and raise_if_not_found:
                raise NotFoundException(model_class.__name__.replace('Model', ''), main_key)
            # Dump while the session is still open.
            dump = get.dump() if hasattr(get, 'dump') else None
            return get, dump

    def get_object_filter(self, model_class: _Base, filt, raise_if_not_found=False):
        """Fetch all rows of ``model_class`` matching ``filt``.

        :param filt: keyword filters passed to ``Query.filter_by``.
        :return: ``(rows, [row.dump() for row in rows])``, or ``(None, None)``
            when nothing matches and ``raise_if_not_found`` is False.
        :raises NotFoundException: when nothing matches and
            ``raise_if_not_found`` is True.
        """
        with self.session() as session:
            rows = session.query(model_class).filter_by(**filt).all()
            if not rows:
                if raise_if_not_found:
                    raise NotFoundException(model_class.__name__.replace('Model', ''))
                return None, None
            # Query.all() always returns a list, so every element is dumped.
            return rows, [obj.dump() for obj in rows]

    def get_or_create(self, model_class: _Base, key_parts: str, filt=None) -> Tuple[_Base, bool]:
        """Return the row matching ``filt``, creating one if none exists.

        :param key_parts: main-primary-key value, used only when ``filt`` is
            falsy (``filt`` then defaults to ``{main_pk_name(): key_parts}``).
        :param filt: keyword filters passed to ``Query.filter_by``.
        :return: ``(existing_row, False)`` or ``(new_row, True)``.
        """
        if not filt:
            filt = {model_class.main_pk_name(): key_parts}
        with self.session() as session:
            get = session.query(model_class).filter_by(**filt).one_or_none()
            if get:
                return get, False
            # NOTE(review): the new object is built with ``model_class()`` and
            # none of the ``filt`` values are applied to it, so the inserted row
            # may not match ``filt`` -- confirm whether callers set the fields
            # afterwards or whether ``model_class(**filt)`` is intended.
            obj = model_class()
            session.add(obj)
            session.commit()
            return obj, True