from typing import Tuple, List

from sqlalchemy import MetaData
from sqlalchemy.orm import Session
from context.service.database.Base import Base
import logging
from common.orm.backend.Tools import key_to_str

from common.rpc_method_wrapper.ServiceExceptions import NotFoundException

LOGGER = logging.getLogger(__name__)


class Database(Session):
    """Small data-access helper built around a session factory.

    Every public method opens a fresh session via ``self.session()``, uses it
    as a context manager, and lets it close before returning. ORM objects
    handed back to callers are therefore no longer attached to a session.

    NOTE(review): the class subclasses ``Session`` and calls
    ``super().__init__()`` with no arguments, but none of the methods below use
    the inherited session state; all work goes through ``self.session``.
    """

    def __init__(self, session):
        """Store the session factory.

        :param session: callable returning a session usable in a ``with``
            statement (presumably a ``sessionmaker``; confirm against caller).
        """
        super().__init__()
        self.session = session

    def _get_engine(self):
        """Return the bind of a freshly opened (and immediately closed) session."""
        with self.session() as session:
            return session.get_bind()

    @staticmethod
    def _reflect_tables(engine):
        """Reflect the live database schema and return its tables in dependency order."""
        meta = MetaData()
        meta.reflect(engine)
        return meta.sorted_tables

    def get_all(self, model):
        """Return every row of ``model`` as a list of ORM instances."""
        with self.session() as session:
            return session.query(model).all()

    def create_or_update(self, model):
        """Insert ``model``, or update the existing row that shares its main primary key.

        :param model: ORM instance whose class provides ``main_pk_name()``.
        :return: ``(model, found)``. ``found`` is True when a row with the same
            main primary key existed before the merge. The instance passed in is
            returned, not the persistent copy produced by ``Session.merge``.
        """
        pk_name = model.main_pk_name()
        filt = {pk_name: getattr(model, pk_name)}
        with self.session() as session:
            found = session.query(type(model)).filter_by(**filt).one_or_none() is not None
            session.merge(model)
            session.commit()
        return model, found

    def create(self, model):
        """Add ``model`` as a new row, commit, and return the same instance."""
        with self.session() as session:
            session.add(model)
            session.commit()
        return model

    def remove(self, model, filter_d):
        """Delete every row of ``type(model)`` that matches ``filter_d``.

        :param model: instance used only for its class; its own field values
            are not used for the match.
        :param filter_d: column-name -> value mapping passed to ``filter_by``.
        """
        model_t = type(model)
        with self.session() as session:
            session.query(model_t).filter_by(**filter_d).delete()
            session.commit()

    def clear(self):
        """Drop and recreate every table registered in ``Base.metadata``."""
        engine = self._get_engine()
        Base.metadata.drop_all(engine)
        Base.metadata.create_all(engine)

    def dump_by_table(self):
        """Return all rows of every reflected table, grouped by table name.

        Tables are discovered by reflecting the live schema, not taken from
        ``Base``. The result is also logged at INFO level.

        :return: ``{table_name: [row_dict, ...]}``.
        """
        engine = self._get_engine()
        result = {}
        for table in self._reflect_tables(engine):
            # NOTE(review): ``Engine.execute`` and ``dict(row)`` are SQLAlchemy 1.x
            # APIs (removed in 2.0); they need migrating if the dependency is upgraded.
            result[table.name] = [dict(row) for row in engine.execute(table.select())]
        LOGGER.info(result)
        return result

    def dump_all(self):
        """Return all rows of every reflected table as a flat list.

        The result is also logged at INFO level.

        :return: ``[(table_name, row_dict), ...]``, with tables in dependency order.
        """
        engine = self._get_engine()
        result = []
        for table in self._reflect_tables(engine):
            # NOTE(review): SQLAlchemy 1.x-only APIs here as well; see dump_by_table.
            for row in engine.execute(table.select()):
                result.append((table.name, dict(row)))
        LOGGER.info(result)

        return result

    def get_object(self, model_class: Base, main_key: str, raise_if_not_found=False):
        """Fetch the ``model_class`` row whose main primary key equals ``main_key``.

        :param model_class: ORM class providing ``main_pk_name()``.
        :param main_key: value of the main primary key to look up.
        :param raise_if_not_found: raise instead of returning None on a miss.
        :return: the matching instance, or None.
        :raises NotFoundException: no row matches and ``raise_if_not_found`` is
            set. The reported entity name is the class name with ``'Model'`` removed.
        """
        filt = {model_class.main_pk_name(): main_key}
        with self.session() as session:
            obj = session.query(model_class).filter_by(**filt).one_or_none()
        if obj is None and raise_if_not_found:
            raise NotFoundException(model_class.__name__.replace('Model', ''), main_key)
        return obj

    def get_or_create(self, model_class: Base, key_parts: List[str]
                      ) -> Tuple[Base, bool]:
        """Return the ``model_class`` row keyed by ``key_parts``, creating it if absent.

        :param model_class: ORM class providing ``main_pk_name()`` and ``dump()``.
        :param key_parts: key components combined into one key by ``key_to_str``.
        :return: ``(instance, created)``. ``created`` is True when a new row was
            inserted.
        """
        str_key = key_to_str(key_parts)
        pk_name = model_class.main_pk_name()
        # The key stored on creation is ``str_key``, so the lookup must use it too.
        # Filtering by the raw ``key_parts`` list never finds a previously created row.
        filt = {pk_name: str_key}
        with self.session() as session:
            existing = session.query(model_class).filter_by(**filt).one_or_none()
            if existing is not None:
                return existing, False
            obj = model_class()
            setattr(obj, pk_name, str_key)
            LOGGER.info(obj.dump())
            session.add(obj)
            session.commit()
            return obj, True