from typing import Tuple, List

from sqlalchemy import MetaData
from sqlalchemy.orm import Session, joinedload
from context.service.database.Base import Base
import logging
from common.orm.backend.Tools import key_to_str

from common.rpc_method_wrapper.ServiceExceptions import NotFoundException

LOGGER = logging.getLogger(__name__)


class Database(Session):
    """Persistence helper that runs each operation in a short-lived session.

    Every public method opens a session via ``self.session()``, does its work
    inside a ``with`` block and lets the session close on exit, so instances
    returned by these methods are no longer attached to an open session.

    NOTE(review): this class inherits from ``Session`` but never uses itself as
    a session; ``super().__init__()`` only builds an unbound base Session.
    """

    def __init__(self, session):
        """Store the session factory.

        :param session: callable returning a context-manager session object
            (presumably a ``sessionmaker``; confirm against the caller).
        """
        super().__init__()
        self.session = session

    def get_session(self):
        """Return the session factory given to the constructor."""
        return self.session

    def get_all(self, model):
        """Return every row of ``model`` as a list of ORM instances."""
        with self.session() as session:
            return session.query(model).all()

    def create_or_update(self, model):
        """Insert ``model``, or copy its column values onto the row sharing its main key.

        :param model: ORM instance; ``model.main_pk_name()`` names the lookup column.
        :returns: ``(instance, found)`` - ``found`` is True when an existing row was
            updated (and that persistent row is returned), False when ``model``
            itself was inserted.
        """
        with self.session() as session:
            pk_name = model.main_pk_name()
            existing = (
                session.query(type(model))
                .filter_by(**{pk_name: getattr(model, pk_name)})
                .one_or_none()
            )
            if existing is None:
                session.add(model)
                session.commit()
                return model, False

            # Copy every mapped column (primary key included) from the incoming object.
            for column_name in existing.__table__.columns.keys():
                setattr(existing, column_name, getattr(model, column_name))
            session.commit()
            return existing, True

    def create(self, model):
        """Insert ``model``, commit, and return the same instance."""
        with self.session() as session:
            session.add(model)
            session.commit()
        return model

    def remove(self, model, filter_d):
        """Bulk-delete rows of ``type(model)`` matching the ``filter_by`` keywords in ``filter_d``.

        ``model`` is only used for its type; its attribute values are ignored.
        """
        with self.session() as session:
            session.query(type(model)).filter_by(**filter_d).delete()
            session.commit()

    def _db_engine(self):
        """Return the bind (presumably an Engine) of a freshly opened session."""
        with self.session() as session:
            return session.get_bind()

    def _db_table_rows(self):
        """Reflect the live schema and return ``[(table_name, [row_dict, ...]), ...]``.

        Tables come in ``MetaData.sorted_tables`` (foreign-key dependency) order.
        """
        engine = self._db_engine()
        meta = MetaData()
        meta.reflect(engine)
        # Engine.execute() and dict(row) are deprecated in SQLAlchemy 1.4 and removed
        # in 2.0; use an explicit connection and the Row._mapping view instead.
        with engine.connect() as connection:
            return [
                (table.name, [dict(row._mapping) for row in connection.execute(table.select())])
                for table in meta.sorted_tables
            ]

    def clear(self):
        """Drop and recreate every table declared on ``Base.metadata``."""
        engine = self._db_engine()
        Base.metadata.drop_all(engine)
        Base.metadata.create_all(engine)

    def dump_by_table(self):
        """Return ``{table_name: [row_dict, ...]}`` for every table in the live schema (also logged)."""
        result = {name: rows for name, rows in self._db_table_rows()}
        LOGGER.info(result)
        return result

    def dump_all(self):
        """Return a flat ``[(table_name, row_dict), ...]`` list covering every table in the live schema."""
        return [(name, row) for name, rows in self._db_table_rows() for row in rows]

    def get_object(self, model_class: Base, main_key: str, raise_if_not_found=False):
        """Fetch the instance of ``model_class`` whose main primary key equals ``main_key``.

        :returns: ``(instance, dump)``; ``dump`` is ``instance.dump()`` when the instance
            has a ``dump`` attribute, else None. Both are None when nothing matches
            and ``raise_if_not_found`` is False.
        :raises NotFoundException: when nothing matches and ``raise_if_not_found`` is True.
        """
        filt = {model_class.main_pk_name(): main_key}
        with self.session() as session:
            get = session.query(model_class).filter_by(**filt).one_or_none()
            if not get and raise_if_not_found:
                raise NotFoundException(model_class.__name__.replace('Model', ''), main_key)
            # dump() is computed while the session is still open.
            dump = get.dump() if hasattr(get, 'dump') else None
            return get, dump

    def get_object_filter(self, model_class: Base, filt: dict, raise_if_not_found=False):
        """Fetch all instances of ``model_class`` matching the ``filter_by`` keywords in ``filt``.

        :returns: ``(instances, dumps)`` - the list of matches and the list of their
            ``dump()`` results; ``(None, None)`` when nothing matches and
            ``raise_if_not_found`` is False.
        :raises NotFoundException: when nothing matches and ``raise_if_not_found`` is True.
        """
        with self.session() as session:
            matches = session.query(model_class).filter_by(**filt).all()
            if not matches:
                if raise_if_not_found:
                    # Pass the filter as the identifier, mirroring get_object's (name, key) call.
                    raise NotFoundException(model_class.__name__.replace('Model', ''), str(filt))
                return None, None
            # Query.all() always returns a list, so every match is dumped here.
            return matches, [obj.dump() for obj in matches]

    def get_or_create(self, model_class: Base, key_parts: str, filt=None) -> Tuple[Base, bool]:
        """Return the instance matching ``filt``, creating one if none exists.

        ``filt`` defaults to ``{main_pk_name: key_parts}``. When creating, only the
        main primary key is set (to ``key_parts``); other entries of a custom
        ``filt`` are not applied to the new instance.

        :returns: ``(instance, created)``.
        """
        if not filt:
            filt = {model_class.main_pk_name(): key_parts}
        with self.session() as session:
            existing = session.query(model_class).filter_by(**filt).one_or_none()
            if existing:
                return existing, False
            obj = model_class()
            setattr(obj, model_class.main_pk_name(), key_parts)
            session.add(obj)
            session.commit()
            return obj, True
