from __future__ import annotations
import logging, re
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple
from common.orm.Database import Database
from common.orm.backend.Tools import key_to_str
from common.orm.fields.ForeignKeyField import ForeignKeyField
from ..Exceptions import ConstraintException, MutexException
from ..fields.Field import Field
from ..fields.PrimaryKeyField import PrimaryKeyField
from .Tools import NoDupOrderedDict

# Logger for this module, named after the module itself.
LOGGER = logging.getLogger(name=__name__)

# Attribute name used for the PrimaryKeyField that MetaModel adds automatically
# to a Model subclass that does not declare one.
DEFAULT_PRIMARY_KEY_NAME = 'pk_auto'

class MetaModel(type):
    """Metaclass for ``Model`` classes.

    At class-creation time it:
      * sets ``.name`` on every ``Field`` attribute to the attribute name it is bound to;
      * records the field attribute names in namespace iteration order;
      * ensures a single ``PrimaryKeyField``. A second one raises ``AttributeError``.
        When none is declared, a ``PrimaryKeyField`` named ``DEFAULT_PRIMARY_KEY_NAME``
        is added, unless that attribute name is already taken (``AttributeError``).

    The results are attached to the new class as ``_pk_field_name``,
    ``_field_names_list`` and ``_field_names_set``.
    """

    @classmethod
    def __prepare__(cls, name : str, bases : Tuple[type, ...], **attrs : Any) -> Mapping[str, Any]:
        # The class body is evaluated into a NoDupOrderedDict. Presumably it rejects
        # attributes defined twice; confirm in .Tools.
        namespace = NoDupOrderedDict()
        return namespace

    def __new__(cls, name : str, bases : Tuple[type, ...], attrs : NoDupOrderedDict[str, Any]):
        declared_fields : List[str] = []
        primary_key_attr : Optional[str] = None

        for attr_name, attr_value in attrs.items():
            if isinstance(attr_value, Field):
                # Let the field know which attribute it is stored under.
                attr_value.name = attr_name
                declared_fields.append(attr_name)
                if isinstance(attr_value, PrimaryKeyField):
                    if primary_key_attr is not None:
                        raise AttributeError('PrimaryKeyField for Model({:s}) already set to attribute({:s})'.format(
                            str(name), str(primary_key_attr)))
                    primary_key_attr = attr_name

        if primary_key_attr is None:
            # No explicit primary key: add an automatic one, unless its name is taken.
            if DEFAULT_PRIMARY_KEY_NAME in attrs.keys():
                template = ('PrimaryKeyField for Model({:s}) not defined and attribute "{:s}" already used. '
                            'Leave attribute name "{:s}" for automatic PrimaryKeyField, or set a PrimaryKeyField.')
                raise AttributeError(template.format(str(name), DEFAULT_PRIMARY_KEY_NAME, DEFAULT_PRIMARY_KEY_NAME))
            primary_key_attr = DEFAULT_PRIMARY_KEY_NAME
            attrs[primary_key_attr] = PrimaryKeyField(name=primary_key_attr)
            declared_fields.append(primary_key_attr)

        new_class = super().__new__(cls, name, bases, dict(attrs))
        new_class._pk_field_name = primary_key_attr
        new_class._field_names_list = declared_fields
        new_class._field_names_set = set(declared_fields)
        return new_class

class Model(metaclass=MetaModel):
    """Base class for objects persisted as a dict in a ``Database`` backend.

    Each instance is stored under the backend key ``'<ClassName>[<primary key>]'``
    (see ``instance_key``). For every ``ForeignKeyField`` that gets saved, the reference
    string ``'<instance_key>:<field_name>'`` is added to the backend set stored under
    ``key_to_str([<foreign key>, 'references'])``. ``delete()`` refuses to remove an
    instance whose own references set exists. ``MetaModel`` attaches the field
    bookkeeping used here (``_pk_field_name``, ``_field_names_list``) to the class.
    """

    def __init__(self, database : Database, primary_key : str, auto_load : bool = True) -> None:
        """Bind the instance to ``database`` and optionally load its stored attributes.

        Args:
            database: Database whose ``backend`` stores this instance.
            primary_key: Primary key value. A full instance key of the form
                ``'<ClassName>[<pk>]'`` is also accepted and reduced to ``<pk>``.
            auto_load: When True (default), ``load()`` is called immediately.

        Raises:
            AttributeError: if ``database`` is not a ``Database`` instance.
        """
        if not isinstance(database, Database):
            str_class_path = '{}.{}'.format(Database.__module__, Database.__name__)
            raise AttributeError('database must inherit from {}'.format(str_class_path))
        self._model_class = type(self)
        self._class_name = self._model_class.__name__
        pk_field_name = self._pk_field_name # pylint: disable=no-member
        pk_field_instance : 'PrimaryKeyField' = getattr(self._model_class, pk_field_name)
        primary_key = pk_field_instance.validate(primary_key)
        # Accept a full instance key ("<ClassName>[<pk>]") and keep only the inner primary key.
        if primary_key.startswith(self._class_name):
            pattern = r'^{:s}\[([^\]]*)\]'.format(re.escape(self._class_name))
            match = re.match(pattern, primary_key)
            if match: primary_key = match.group(1)
        setattr(self, pk_field_name, primary_key)
        self._database = database
        self._backend = database.backend
        self._instance_key : str = '{:s}[{:s}]'.format(self._class_name, primary_key)
        self._references_key : str = key_to_str([self._instance_key, 'references'])
        # Owner token returned by the backend on the first lock; reused for later lock/unlock calls.
        self._owner_key : Optional[str] = None
        if auto_load: self.load()

    @property
    def instance_key(self) -> str: return self._instance_key

    def _build_lock_keys(self, extra_keys : Optional[List[str]]) -> List[str]:
        """Return the lock keys for this instance, its references set and ``extra_keys``."""
        keys = [self._instance_key, self._references_key] + list(extra_keys or [])
        return [key_to_str([key, 'lock']) for key in keys]

    def lock(self, extra_keys : Optional[List[str]] = None) -> None:
        """Acquire backend locks on this instance, its references set and ``extra_keys``.

        Raises:
            MutexException: if the backend does not grant the locks.
        """
        lock_keys = self._build_lock_keys(extra_keys)
        acquired, self._owner_key = self._backend.lock(lock_keys, owner_key=self._owner_key)
        if acquired: return
        raise MutexException('Unable to lock keys {:s} using owner_key {:s}'.format(
            str(lock_keys), str(self._owner_key)))

    def unlock(self, extra_keys : Optional[List[str]] = None) -> None:
        """Release the locks taken by ``lock()`` with the same ``extra_keys``.

        Raises:
            MutexException: if the backend reports the locks were not released.
        """
        lock_keys = self._build_lock_keys(extra_keys)
        released = self._backend.unlock(lock_keys, self._owner_key)
        if released: return
        raise MutexException('Unable to unlock keys {:s} using owner_key {:s}'.format(
            str(lock_keys), str(self._owner_key)))

    def load(self) -> None:
        """Populate field attributes from the dict stored under ``instance_key``.

        Does nothing when the backend returns None for the key. The primary key field
        and fields missing from the stored dict are left untouched. For a
        ``ForeignKeyField``, the deserialized stored value is kept in
        ``<field>_stored``, and the attribute is set to a ``foreign_model`` instance
        built from it (with ``auto_load=True``).
        """
        pk_field_name = self._pk_field_name # pylint: disable=no-member

        # Acquire before entering try: a failed lock must not trigger unlock(), which would
        # raise a second, misleading MutexException and hide the real failure.
        self.lock()
        try:
            attributes = self._backend.dict_get(self._instance_key)
            if attributes is None: return
            for field_name in self._field_names_list: # pylint: disable=no-member
                if field_name == pk_field_name: continue
                if field_name not in attributes: continue
                raw_field_value = attributes[field_name]
                field_instance : 'Field' = getattr(self._model_class, field_name)
                field_value = field_instance.deserialize(raw_field_value)
                if isinstance(field_instance, ForeignKeyField):
                    setattr(self, field_name + '_stored', field_value)
                    field_value = field_instance.foreign_model(self._database, field_value, auto_load=True)
                setattr(self, field_name, field_value)
        finally:
            self.unlock()

    def save(self) -> None:
        """Serialize the fields, write them under ``instance_key`` and update reference sets.

        Fields whose serialized value is None are skipped unless the field is required.
        For each ``ForeignKeyField``, the reference to the previously stored target is
        removed and a reference to the new target is added.

        Raises:
            ConstraintException: if a referenced foreign key does not exist in the backend.
            MutexException: if the backend locks cannot be acquired or released.
        """
        attributes : Dict[str, Any] = dict()
        required_keys : Set[str] = set()
        # (references-set key, reference) pairs. A list rather than a dict, so two foreign-key
        # fields pointing at the same instance do not overwrite each other's reference.
        foreign_additions : List[Tuple[str, str]] = []
        foreign_removals : List[Tuple[str, str]] = []
        # New value to remember in "<field>_stored" once the write succeeds.
        new_stored_values : Dict[str, Any] = dict()
        for field_name in self._field_names_list: # pylint: disable=no-member
            field_value = getattr(self, field_name)
            field_instance : 'Field' = getattr(self._model_class, field_name)
            serialized_field_value = field_instance.serialize(field_value)
            if (serialized_field_value is None) and (not field_instance.required): continue
            if isinstance(field_instance, ForeignKeyField):
                foreign_reference = '{:s}:{:s}'.format(self._instance_key, field_name)
                field_value_stored = getattr(self, field_name + '_stored', None)
                if field_value_stored is not None:
                    foreign_removals.append((key_to_str([field_value_stored, 'references']), foreign_reference))
                foreign_additions.append((key_to_str([serialized_field_value, 'references']), foreign_reference))
                required_keys.add(serialized_field_value)
                new_stored_values[field_name] = serialized_field_value
            attributes[field_name] = serialized_field_value

        # An unchanged foreign key yields the same set key for removal and addition.
        # Request each lock key only once, keeping first-seen order.
        extra_keys = list(dict.fromkeys(
            set_key for set_key, _ in foreign_removals + foreign_additions))

        self.lock(extra_keys=extra_keys)
        try:
            not_exists = sorted(
                str(required_key) for required_key in required_keys
                if not self._backend.exists(required_key))
            if len(not_exists) > 0:
                raise ConstraintException('Required Keys ({:s}) does not exist'.format(', '.join(not_exists)))

            self._backend.dict_update(self._instance_key, attributes)
            # Removals run before additions, so an unchanged reference ends up present.
            for references_key, foreign_reference in foreign_removals:
                self._backend.set_remove(references_key, foreign_reference)
            for references_key, foreign_reference in foreign_additions:
                self._backend.set_add(references_key, foreign_reference)
        finally:
            self.unlock(extra_keys=extra_keys)

        # Remember the references just written, so a later save()/delete() removes these.
        for field_name, stored_value in new_stored_values.items():
            setattr(self, field_name + '_stored', stored_value)

    def delete(self) -> None:
        """Delete this instance from the backend and drop its outgoing foreign references.

        Raises:
            ConstraintException: if this instance's references set exists, i.e. other
                instances still refer to it.
            MutexException: if the backend locks cannot be acquired or released.
        """
        # List of (references-set key, reference) pairs; see save() for why not a dict.
        foreign_removals : List[Tuple[str, str]] = []
        for field_name in self._field_names_list: # pylint: disable=no-member
            field_instance : 'Field' = getattr(self._model_class, field_name)
            if not isinstance(field_instance, ForeignKeyField): continue
            field_value_stored = getattr(self, field_name + '_stored', None)
            if field_value_stored is None: continue
            foreign_reference = '{:s}:{:s}'.format(self._instance_key, field_name)
            foreign_removals.append((key_to_str([field_value_stored, 'references']), foreign_reference))

        extra_keys = list(dict.fromkeys(set_key for set_key, _ in foreign_removals))

        self.lock(extra_keys=extra_keys)
        try:
            if self._backend.exists(self._references_key):
                references = self._backend.set_get_all(self._references_key)
                raise ConstraintException('Instance is used by Keys ({:s})'.format(', '.join(sorted(references))))

            self._backend.delete(self._instance_key)
            for references_key, foreign_reference in foreign_removals:
                self._backend.set_remove(references_key, foreign_reference)
        finally:
            self.unlock(extra_keys=extra_keys)

    def references(self) -> Set[Tuple[str, str]]:
        """Return the ``(instance_key, field_name)`` pairs of instances referring to this one.

        Returns an empty set when this instance's references set does not exist.
        """
        self.lock()
        try:
            if not self._backend.exists(self._references_key): return set()
            references = self._backend.set_get_all(self._references_key)
            return {tuple(reference.rsplit(':', 1)) for reference in references}
        finally:
            self.unlock()

    def dump_id(self) -> Dict:
        raise NotImplementedError()

    def dump(self) -> Dict:
        raise NotImplementedError()

    def __repr__(self) -> str:
        pk_field_name = self._pk_field_name # pylint: disable=no-member
        arguments = ', '.join(
            '{:s}={:s}{:s}'.format(
                name, repr(getattr(self, name)), '(PK)' if name == pk_field_name else '')
            for name in self._field_names_list # pylint: disable=no-member
        )
        return '{:s}({:s})'.format(self._class_name, arguments)
