from __future__ import annotations

import logging
import re
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple

from common.orm.Database import Database
from common.orm.backend.Tools import key_to_str
from common.orm.fields.ForeignKeyField import ForeignKeyField
from ..Exceptions import ConstraintException, MutexException
from ..fields.Field import Field
from ..fields.PrimaryKeyField import PrimaryKeyField
from .Tools import NoDupOrderedDict

LOGGER = logging.getLogger(__name__)

# Attribute name of the PrimaryKeyField that MetaModel adds automatically when a
# model class does not declare a PrimaryKeyField itself.
DEFAULT_PRIMARY_KEY_NAME = 'pk_auto'


class MetaModel(type):
    """Metaclass that collects the Field attributes declared in a model class body.

    For every class it builds, it sets these class attributes:
      * ``_pk_field_name``: name of the single PrimaryKeyField attribute (an
        automatic one named DEFAULT_PRIMARY_KEY_NAME is added if none is declared);
      * ``_field_names_list``: Field attribute names in declaration order, with an
        automatically added primary key appended last;
      * ``_field_names_set``: the same names as a set.

    Only attributes declared directly in the class body are inspected; Fields of
    base classes are not merged in.

    Raises:
        AttributeError: if more than one PrimaryKeyField is declared, or if no
            PrimaryKeyField is declared but DEFAULT_PRIMARY_KEY_NAME is already used.
    """

    @classmethod
    def __prepare__(cls, name : str, bases : Tuple[type, ...], **attrs : Any) -> Mapping[str, Any]:
        # Namespace for the class body; NoDupOrderedDict keeps declaration order.
        # NOTE(review): presumably it also rejects redefining a name — confirm in .Tools.
        return NoDupOrderedDict()

    def __new__(cls, name : str, bases : Tuple[type, ...], attrs : NoDupOrderedDict[str, Any]):
        field_names : List[str] = []
        pk_field_name : Optional[str] = None
        for key, value in attrs.items():
            if not isinstance(value, Field):
                continue
            value.name = key  # let the field know the attribute name it is bound to
            field_names.append(key)
            if not isinstance(value, PrimaryKeyField):
                continue
            if pk_field_name is not None:
                raise AttributeError('PrimaryKeyField for Model({:s}) already set to attribute({:s})'.format(
                    str(name), str(pk_field_name)))
            pk_field_name = key

        if pk_field_name is None:
            if DEFAULT_PRIMARY_KEY_NAME in attrs.keys():
                msg = 'PrimaryKeyField for Model({:s}) not defined and attribute "{:s}" already used. '\
                      'Leave attribute name "{:s}" for automatic PrimaryKeyField, or set a PrimaryKeyField.'
                raise AttributeError(msg.format(str(name), DEFAULT_PRIMARY_KEY_NAME, DEFAULT_PRIMARY_KEY_NAME))
            pk_field_name = DEFAULT_PRIMARY_KEY_NAME
            attrs[pk_field_name] = PrimaryKeyField(name=pk_field_name)
            field_names.append(pk_field_name)

        cls_obj = super().__new__(cls, name, bases, dict(attrs))
        setattr(cls_obj, '_pk_field_name', pk_field_name)
        setattr(cls_obj, '_field_names_list', field_names)
        setattr(cls_obj, '_field_names_set', set(field_names))
        return cls_obj


class Model(metaclass=MetaModel):
    """Base class for objects persisted in a ``Database`` backend.

    Storage layout, as used by the methods below:
      * the instance's serialized fields live in a dict stored at the instance key
        ``'<ClassName>[<primary key>]'``;
      * the set stored at ``key_to_str([instance_key, 'references'])`` holds one
        ``'<referrer instance key>:<field name>'`` entry for every ForeignKeyField
        (of any model) that points at this instance;
      * every key touched by an operation is guarded by a lock on
        ``key_to_str([key, 'lock'])``.

    For each ForeignKeyField ``f``, the attribute ``f_stored`` remembers the
    foreign instance key currently persisted, so that save()/delete() can remove
    the corresponding reference entry.
    """

    def __init__(self, database : Database, primary_key : str, auto_load : bool = True) -> None:
        """Bind the instance to ``database`` under ``primary_key``.

        Args:
            database: Database whose backend stores the instance.
            primary_key: bare primary key, or a full instance key
                ``'<ClassName>[<pk>]'`` from which the ``<pk>`` part is extracted.
            auto_load: when True, load() is called before returning.

        Raises:
            AttributeError: if ``database`` is not a ``Database``.
        """
        if not isinstance(database, Database):
            str_class_path = '{}.{}'.format(Database.__module__, Database.__name__)
            raise AttributeError('database must inherit from {}'.format(str_class_path))
        self._model_class = type(self)
        self._class_name = self._model_class.__name__
        pk_field_name = self._pk_field_name # pylint: disable=no-member
        pk_field_instance : 'PrimaryKeyField' = getattr(self._model_class, pk_field_name)
        primary_key = pk_field_instance.validate(primary_key)
        if primary_key.startswith(self._class_name):
            # Accept a full instance key and keep only the part between the brackets;
            # the class name is escaped so it is matched literally.
            match = re.match(r'^{:s}\[([^\]]*)\]'.format(re.escape(self._class_name)), primary_key)
            if match:
                primary_key = match.group(1)
        setattr(self, pk_field_name, primary_key)
        self._database = database
        self._backend = database.backend
        self._instance_key : str = '{:s}[{:s}]'.format(self._class_name, primary_key)
        self._references_key : str = key_to_str([self._instance_key, 'references'])
        self._owner_key : Optional[str] = None
        if auto_load:
            self.load()

    @property
    def instance_key(self) -> str:
        """Backend key under which this instance's fields are stored."""
        return self._instance_key

    def _build_lock_keys(self, extra_keys : Optional[List[str]]) -> List[str]:
        """Return the lock keys for the instance key, references key and ``extra_keys``."""
        keys = [self._instance_key, self._references_key]
        if extra_keys:
            keys.extend(extra_keys)
        return [key_to_str([key, 'lock']) for key in keys]

    def lock(self, extra_keys : Optional[List[str]] = None):
        """Lock the instance key, its references key and ``extra_keys``.

        The owner key returned by the backend is kept and reused on later calls.

        Raises:
            MutexException: if the backend does not grant the lock.
        """
        lock_keys = self._build_lock_keys(extra_keys)
        acquired, self._owner_key = self._backend.lock(lock_keys, owner_key=self._owner_key)
        if not acquired:
            raise MutexException('Unable to lock keys {:s} using owner_key {:s}'.format(
                str(lock_keys), str(self._owner_key)))

    def unlock(self, extra_keys : Optional[List[str]] = None):
        """Release the locks taken by lock() with the same ``extra_keys``.

        Raises:
            MutexException: if the backend does not release the lock.
        """
        lock_keys = self._build_lock_keys(extra_keys)
        released = self._backend.unlock(lock_keys, self._owner_key)
        if not released:
            raise MutexException('Unable to unlock keys {:s} using owner_key {:s}'.format(
                str(lock_keys), str(self._owner_key)))

    def load(self) -> None:
        """Populate fields from the backend; does nothing if the instance is not stored.

        ForeignKeyField attributes are replaced by instances of their foreign model
        (loaded as well), and ``<field>_stored`` keeps the raw foreign key.
        """
        pk_field_name = self._pk_field_name # pylint: disable=no-member
        # Acquire before entering try: if locking fails there is nothing to unlock,
        # and unlocking would only mask the original MutexException.
        self.lock()
        try:
            attributes = self._backend.dict_get(self._instance_key)
            if attributes is None:
                return
            for field_name in self._field_names_list: # pylint: disable=no-member
                if field_name == pk_field_name or field_name not in attributes:
                    continue
                field_instance : 'Field' = getattr(self._model_class, field_name)
                field_value = field_instance.deserialize(attributes[field_name])
                if isinstance(field_instance, ForeignKeyField):
                    setattr(self, field_name + '_stored', field_value)
                    field_value = field_instance.foreign_model(self._database, field_value, auto_load=True)
                setattr(self, field_name, field_value)
        finally:
            self.unlock()

    def save(self) -> None:
        """Persist the fields and keep the foreign instances' reference sets in sync.

        Raises:
            ConstraintException: if a referenced foreign instance key does not exist.
            MutexException: if the involved keys cannot be locked/unlocked.
        """
        attributes : Dict[str, Any] = dict()
        required_keys : Set[str] = set()
        # (references_key, reference) pairs. Lists rather than dicts keyed by
        # references_key, so two FK fields pointing at the same foreign instance
        # each keep their own entry.
        foreign_removals : List[Tuple[str, str]] = []
        foreign_additions : List[Tuple[str, str]] = []
        # field_name -> foreign instance key being written; becomes '<field>_stored'.
        new_stored_values : Dict[str, Any] = {}
        for field_name in self._field_names_list: # pylint: disable=no-member
            field_value = getattr(self, field_name)
            field_instance : 'Field' = getattr(self._model_class, field_name)
            serialized_field_value = field_instance.serialize(field_value)
            # NOTE(review): a non-required FK set to None is skipped here, so a
            # previously stored reference for it is not removed — confirm intent.
            if (serialized_field_value is None) and (not field_instance.required):
                continue
            if isinstance(field_instance, ForeignKeyField):
                foreign_reference = '{:s}:{:s}'.format(self._instance_key, field_name)
                field_value_stored = getattr(self, field_name + '_stored', None)
                if field_value_stored is not None:
                    foreign_removals.append((key_to_str([field_value_stored, 'references']), foreign_reference))
                foreign_additions.append((key_to_str([serialized_field_value, 'references']), foreign_reference))
                required_keys.add(serialized_field_value)
                new_stored_values[field_name] = serialized_field_value
            attributes[field_name] = serialized_field_value

        # Lock every references set touched below; dict.fromkeys drops duplicates
        # while keeping the order (removals first, then additions).
        extra_keys = list(dict.fromkeys(
            [references_key for references_key, _ in foreign_removals] +
            [references_key for references_key, _ in foreign_additions]))

        self.lock(extra_keys=extra_keys)
        try:
            not_exists = [str(required_key) for required_key in required_keys
                          if not self._backend.exists(required_key)]
            if len(not_exists) > 0:
                raise ConstraintException('Required Keys ({:s}) does not exist'.format(', '.join(sorted(not_exists))))
            self._backend.dict_update(self._instance_key, attributes)
            # Removals first: an unchanged FK is removed and re-added, ending present.
            for references_key, foreign_reference in foreign_removals:
                self._backend.set_remove(references_key, foreign_reference)
            for references_key, foreign_reference in foreign_additions:
                self._backend.set_add(references_key, foreign_reference)
        finally:
            self.unlock(extra_keys=extra_keys)

        # Remember what is now persisted, so later save()/delete() calls remove the
        # right reference entries.
        for field_name, stored_value in new_stored_values.items():
            setattr(self, field_name + '_stored', stored_value)

    def delete(self) -> None:
        """Delete the instance and drop the reference entries it owns.

        Raises:
            ConstraintException: if other instances still reference this one.
            MutexException: if the involved keys cannot be locked/unlocked.
        """
        foreign_removals : List[Tuple[str, str]] = []
        for field_name in self._field_names_list: # pylint: disable=no-member
            field_instance : 'Field' = getattr(self._model_class, field_name)
            if not isinstance(field_instance, ForeignKeyField):
                continue
            field_value_stored = getattr(self, field_name + '_stored', None)
            if field_value_stored is None:
                continue
            foreign_reference = '{:s}:{:s}'.format(self._instance_key, field_name)
            foreign_removals.append((key_to_str([field_value_stored, 'references']), foreign_reference))

        extra_keys = list(dict.fromkeys(references_key for references_key, _ in foreign_removals))

        self.lock(extra_keys=extra_keys)
        try:
            if self._backend.exists(self._references_key):
                references = self._backend.set_get_all(self._references_key)
                raise ConstraintException('Instance is used by Keys ({:s})'.format(', '.join(sorted(references))))
            self._backend.delete(self._instance_key)
            for references_key, foreign_reference in foreign_removals:
                self._backend.set_remove(references_key, foreign_reference)
        finally:
            self.unlock(extra_keys=extra_keys)

    def references(self) -> Set[Tuple[str, str]]:
        """Return ``(referrer instance key, field name)`` pairs that point at this instance.

        Returns an empty set when nothing references the instance.
        """
        self.lock()
        try:
            if not self._backend.exists(self._references_key):
                return set()
            references = self._backend.set_get_all(self._references_key)
            return {tuple(reference.rsplit(':', 1)) for reference in references}
        finally:
            self.unlock()

    def dump_id(self) -> Dict:
        """Return the identifier representation; must be implemented by subclasses."""
        raise NotImplementedError()

    def dump(self) -> Dict:
        """Return the full representation; must be implemented by subclasses."""
        raise NotImplementedError()

    def __repr__(self) -> str:
        pk_field_name = self._pk_field_name # pylint: disable=no-member
        arguments = ', '.join(
            '{:s}={:s}{:s}'.format(
                name, repr(getattr(self, name)), '(PK)' if name == pk_field_name else '')
            for name in self._field_names_list # pylint: disable=no-member
        )
        return '{:s}({:s})'.format(self._class_name, arguments)