# Newer / Older  (repository-viewer navigation residue; not part of the module)
# Copyright 2021-2023 H2020 TeraFlow (https://www.teraflow-h2020.eu/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
from common.orm.Database import Database
from common.orm.backend.Tools import key_to_str
from common.orm.fields.ForeignKeyField import ForeignKeyField
from ..Exceptions import ConstraintException, MutexException
from ..fields.PrimaryKeyField import PrimaryKeyField
# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(__name__)
# Attribute name given to the PrimaryKeyField that is created automatically
# when a model class does not declare one itself.
DEFAULT_PRIMARY_KEY_NAME = 'pk_auto'
# NOTE(review): indentation and several lines of this section appear to have been lost: no enclosing class
# statement is visible, `__prepare__` has no body, and `field_names` / `pk_field_name` are used below without a
# visible initialisation. Confirm against the upstream file before relying on this text.
# `__prepare__` is the hook Python calls to obtain the namespace mapping in which a class body is evaluated.
def __prepare__(cls, name : str, bases : Tuple[type, ...], **attrs : Any) -> Mapping[str, Any]:
# Builds the class object: locates the PrimaryKeyField among the declared attributes (creating a default one
# when none is declared) and records the field names on the new class.
# Raises AttributeError when a second PrimaryKeyField is found, or when no PrimaryKeyField is declared while
# the default primary-key attribute name is already taken by another attribute.
def __new__(cls, name : str, bases : Tuple[type, ...], attrs : NoDupOrderedDict[str, Any]):
for key, value in attrs.items():
# Attributes that are not Field instances are ignored.
if not isinstance(value, Field): continue
# From here on only PrimaryKeyField attributes are of interest.
if not isinstance(value, PrimaryKeyField): continue
if pk_field_name is None:
pk_field_name = key
# NOTE(review): as shown, this raise follows the assignment directly; presumably a `continue` after the
# first PrimaryKeyField was lost, so that only a second one triggers the error -- TODO confirm.
raise AttributeError('PrimaryKeyField for Model({:s}) already set to attribute({:s})'.format(
str(name), str(pk_field_name)))
# No PrimaryKeyField declared: fall back to an automatic one named DEFAULT_PRIMARY_KEY_NAME.
if pk_field_name is None:
# The default name must still be free, otherwise the automatic field would overwrite a user attribute.
if DEFAULT_PRIMARY_KEY_NAME in attrs.keys():
msg = 'PrimaryKeyField for Model({:s}) not defined and attribute "{:s}" already used. '\
'Leave attribute name "{:s}" for automatic PrimaryKeyField, or set a PrimaryKeyField.'
raise AttributeError(msg.format(str(name), DEFAULT_PRIMARY_KEY_NAME, DEFAULT_PRIMARY_KEY_NAME))
pk_field_name = DEFAULT_PRIMARY_KEY_NAME
attrs[pk_field_name] = PrimaryKeyField(name=pk_field_name)
field_names.append(pk_field_name)
# The namespace is converted to a plain dict before being handed to the parent metaclass.
cls_obj = super().__new__(cls, name, bases, dict(attrs))
# Metadata read later by the model's instance methods.
setattr(cls_obj, '_pk_field_name', pk_field_name)
setattr(cls_obj, '_field_names_list', field_names)
setattr(cls_obj, '_field_names_set', set(field_names))
return cls_obj
# Keywords used when composing backend keys and attribute names.
KEYWORD_INSTANCES = 'instances' # suffix of the key built by get_backend_key_instances
KEYWORD_LOCK = 'lock' # suffix of the key built by get_backend_key_lock
KEYWORD_REFERENCES = 'references' # suffix of the key built by get_backend_key_references
KEYWORD_STORED = '_stored' # appended to a foreign-key field name to hold the value read from the backend
@classmethod
def get_backend_key_instances(cls) -> str:
    """Return the backend key built from this model's class name and the instances keyword."""
    model_name = '{:s}'.format(cls.__name__)
    key_parts = [model_name, KEYWORD_INSTANCES]
    return key_to_str(key_parts)
@classmethod
def get_backend_key_instance(cls, primary_key : str) -> str:
    """Return the backend key of one instance, formatted as ``ClassName[primary_key]``."""
    model_name = cls.__name__
    return '{:s}[{:s}]'.format(model_name, primary_key)
@classmethod
def get_backend_key_references(cls, primary_key : str) -> str:
    """Return the backend key of the references entry belonging to an instance.

    ``primary_key`` may be a bare primary key or an already formatted ``Name[pk]`` instance key; a bare
    primary key is first expanded with this class's name.
    """
    already_instance_key = re.match(r'^[a-zA-Z0-9\_]+\[([^\]]*)\]', primary_key)
    if already_instance_key is None:
        primary_key = cls.get_backend_key_instance(primary_key)
    key_parts = [primary_key, KEYWORD_REFERENCES]
    return key_to_str(key_parts)
@staticmethod
def get_backend_key_lock(backend_key : str) -> str:
    """Return the lock key for ``backend_key``; a key already ending in the lock keyword is returned as is."""
    is_lock_key = backend_key.endswith(KEYWORD_LOCK)
    return backend_key if is_lock_key else key_to_str([backend_key, KEYWORD_LOCK])
@staticmethod
def get_backend_key_locks(backend_keys : List[str]) -> List[str]:
    """Return the lock key of every entry in ``backend_keys``, preserving order."""
    lock_keys = []
    for backend_key in backend_keys:
        lock_keys.append(Model.get_backend_key_lock(backend_key))
    return lock_keys
@classmethod
def backend_key__to__instance_key(cls, backend_key : str) -> str:
    """Extract the primary key from a ``ClassName[pk]`` backend key of this class.

    Any string that does not start with ``ClassName[...]`` is returned unchanged.
    """
    class_name = cls.__name__
    # Cheap prefix test first; the regular expression only runs for plausible candidates.
    if not backend_key.startswith(class_name):
        return backend_key
    found = re.match(r'^{:s}\[([^\]]*)\]'.format(class_name), backend_key)
    return backend_key if found is None else found.group(1)
def __init__(self, database : Database, primary_key : str, auto_load : bool = True) -> None:
    """Bind the instance to a database and derive the backend keys it works with.

    Parameters:
        database: Database whose ``backend`` attribute is used for every storage operation.
        primary_key: primary key of the instance; it is validated by the model's PrimaryKeyField, and a
            formatted ``ClassName[pk]`` key is reduced to its bare primary key.
        auto_load: accepted for interface compatibility; nothing in this constructor reads it.
            NOTE(review): presumably it should trigger a load from the backend -- confirm against upstream.

    Raises:
        AttributeError: if ``database`` is not a Database instance.
    """
    if not isinstance(database, Database):
        str_class_path = '{}.{}'.format(Database.__module__, Database.__name__)
        raise AttributeError('database must inherit from {}'.format(str_class_path))

    self._model_class = type(self)
    self._class_name = self._model_class.__name__

    # Validate the primary key with the model's PrimaryKeyField and store it under that field's name.
    pk_field_name = self._pk_field_name # pylint: disable=no-member
    pk_field_instance : 'PrimaryKeyField' = getattr(self._model_class, pk_field_name)
    primary_key = pk_field_instance.validate(primary_key)
    primary_key = self.backend_key__to__instance_key(primary_key)
    setattr(self, pk_field_name, primary_key)

    self._database = database
    self._backend = database.backend
    self._instance_key : str = self.get_backend_key_instance(primary_key)
    self._instances_key : str = self.get_backend_key_instances()
    self._references_key : str = self.get_backend_key_references(primary_key)

    # lock() passes self._owner_key to the backend on its very first call, so the attribute must exist
    # before any locking happens; None means "no owner key obtained yet".
    self._owner_key : Optional[str] = None
@property
def database(self) -> Database:
    """Database object this instance was constructed with."""
    return self._database
def instance_key(self) -> str:
    """Return the backend key under which this instance's attributes are stored."""
    # NOTE(review): the neighbouring `database` accessor is a property; confirm whether this one should be too.
    return self._instance_key
def lock(self, extra_keys : Optional[List[str]] = None, blocking : bool = True):
    """Acquire the backend locks guarding this instance.

    The locked keys are the lock keys of the instance key, the instances key and the references key, plus
    those of any ``extra_keys``.

    Parameters:
        extra_keys: additional backend keys to lock together with the instance's own keys; None or an
            empty sequence means no extra keys.
        blocking: when True the acquisition is retried in a loop (without any delay) until it succeeds;
            when False a single attempt is made.

    Raises:
        MutexException: if ``blocking`` is False and the locks could not be acquired.
    """
    own_keys = [self._instance_key, self._instances_key, self._references_key]
    # The key list does not change between retries, so it is built once outside the loop.
    lock_keys = Model.get_backend_key_locks(own_keys + list(extra_keys or []))
    while True:
        # The backend returns the owner key to be reused by later lock/unlock calls.
        acquired,self._owner_key = self._backend.lock(lock_keys, owner_key=self._owner_key)
        if acquired: return
        if not blocking: break
    raise MutexException('Unable to lock keys {:s} using owner_key {:s}'.format(
        str(lock_keys), str(self._owner_key)))
def unlock(self, extra_keys : Optional[List[str]] = None):
    """Release the backend locks guarding this instance.

    The released keys are the lock keys of the instance key, the instances key and the references key,
    plus those of any ``extra_keys``; pass the same ``extra_keys`` that were given to ``lock``.

    Parameters:
        extra_keys: additional backend keys whose locks are released as well; None or an empty sequence
            means no extra keys.

    Raises:
        MutexException: if the backend reports that the locks were not released.
    """
    own_keys = [self._instance_key, self._instances_key, self._references_key]
    lock_keys = Model.get_backend_key_locks(own_keys + list(extra_keys or []))
    released = self._backend.unlock(lock_keys, self._owner_key)
    if released: return
    # The message has two placeholders: the keys and the owner key, mirroring the message raised by lock().
    raise MutexException('Unable to unlock keys {:s} using owner_key {:s}'.format(
        str(lock_keys), str(self._owner_key)))
# NOTE(review): the `def` line introducing this body is not visible, and the `try:` below has no visible
# `except`/`finally` clause -- lines appear to be missing; confirm against the upstream file.
# Reads the attributes stored under this instance's key and copies them onto the object.
pk_field_name = self._pk_field_name # pylint: disable=no-member
try:
self.lock()
attributes = self._backend.dict_get(self._instance_key)
# Nothing stored under this key: signalled to the caller by returning False.
if attributes is None or len(attributes) == 0: return False
for field_name in self._field_names_list: # pylint: disable=no-member
# The primary key is not read from the stored attributes.
if field_name == pk_field_name: continue
# Fields without a stored value keep whatever value the object currently has.
if field_name not in attributes: continue
raw_field_value = attributes[field_name]
field_instance : 'Field' = getattr(self._model_class, field_name)
field_value = field_instance.deserialize(raw_field_value)
if isinstance(field_instance, ForeignKeyField):
# Keep the raw foreign key under `<field_name>_stored`, then replace the value with an instance of the
# foreign model constructed from that key.
setattr(self, field_name + KEYWORD_STORED, field_value)
field_value = field_instance.foreign_model(self._database, field_value, auto_load=True)
setattr(self, field_name, field_value)
# NOTE(review): the `def` line introducing this body is not visible and some guard lines appear to be missing
# (for example, the ConstraintException below is raised unconditionally as shown) -- confirm against upstream.
# Serialises every field, works out which foreign-key reference entries must change, and writes everything to
# the backend while holding the locks.
attributes : Dict[str, Any] = dict() # field name -> serialized value to write
required_keys : Set[str] = set() # backend keys that are checked for existence before writing
foreign_additions : Dict[str, str] = dict() # references key -> reference string to add
foreign_removals : Dict[str, str] = dict() # references key -> reference string to remove
for field_name in self._field_names_list: # pylint: disable=no-member
field_value = getattr(self, field_name)
field_instance : 'Field' = getattr(self._model_class, field_name)
serialized_field_value = field_instance.serialize(field_value)
# Optional fields without a value are not written at all.
if (serialized_field_value is None) and (not field_instance.required): continue
if isinstance(field_instance, ForeignKeyField):
# A reference is recorded as '<instance key>:<field name>'.
foreign_reference = '{:s}:{:s}'.format(self._instance_key, field_name)
# Value kept under `<field_name>_stored` (None when the attribute does not exist).
field_value_stored = getattr(self, field_name + KEYWORD_STORED, None)
# NOTE(review): no None check is visible before the next line -- verify against upstream.
foreign_removals[self.get_backend_key_references(field_value_stored)] = foreign_reference
foreign_additions[self.get_backend_key_references(serialized_field_value)] = foreign_reference
required_keys.add(serialized_field_value)
attributes[field_name] = serialized_field_value
# The references keys that will be modified are locked together with the instance's own keys.
extra_keys = []
extra_keys.extend(list(foreign_removals.keys()))
extra_keys.extend(list(foreign_additions.keys()))
try:
self.lock(extra_keys=extra_keys)
# Collect the required keys that the backend does not know about.
not_exists = [
str(required_key)
for required_key in required_keys
if not self._backend.exists(required_key)]
raise ConstraintException('Required Keys ({:s}) does not exist'.format(', '.join(sorted(not_exists))))
# Write the attributes and register the instance key under the instances key.
self._backend.dict_update(self._instance_key, attributes)
self._backend.set_add(self._instances_key, self._instance_key)
# Here `serialized_field_value` iterates over references keys (the dict keys built above).
for serialized_field_value,foreign_reference in foreign_removals.items():
self._backend.set_remove(serialized_field_value, foreign_reference)
for serialized_field_value,foreign_reference in foreign_additions.items():
self._backend.set_add(serialized_field_value, foreign_reference)
finally:
self.unlock(extra_keys=extra_keys)
# Sets a `<field name>_stored` attribute for each added reference; the field name is the part of the reference
# string after its last ':'.
# NOTE(review): the value written is `field_value_stored` (left over from the loop above), not a value derived
# from the loop variables -- verify against upstream.
for serialized_field_value,foreign_reference in foreign_additions.items():
setattr(self, (foreign_reference.rsplit(':', 1)[-1]) + KEYWORD_STORED, field_value_stored)
# NOTE(review): the `def` line introducing this body is not visible, and indentation has been lost; the
# ConstraintException below presumably belongs inside the preceding `if` -- confirm against upstream.
# Removes this instance from the backend, together with the reference entries it registered on the
# instances it points to.
foreign_removals : Dict[str, str] = {} # references key -> reference string to remove
for field_name in self._field_names_list: # pylint: disable=no-member
field_instance : 'Field' = getattr(self._model_class, field_name)
# Only foreign-key fields register references.
if not isinstance(field_instance, ForeignKeyField): continue
# A reference is recorded as '<instance key>:<field name>'.
foreign_reference = '{:s}:{:s}'.format(self._instance_key, field_name)
field_value_stored = getattr(self, field_name + KEYWORD_STORED, None)
# No stored foreign key for this field: nothing to remove.
if field_value_stored is None: continue
foreign_removals[self.get_backend_key_references(field_value_stored)] = foreign_reference
# The references keys that will be modified are locked together with the instance's own keys.
extra_keys = []
extra_keys.extend(list(foreign_removals.keys()))
try:
self.lock(extra_keys=extra_keys)
# An existing references key means other entries still point at this instance.
if self._backend.exists(self._references_key):
references = self._backend.set_get_all(self._references_key)
raise ConstraintException('Instance is used by Keys ({:s})'.format(', '.join(sorted(references))))
self._backend.delete(self._instance_key)
self._backend.set_remove(self._instances_key, self._instance_key)
# Here `serialized_field_value` iterates over references keys (the dict keys built above).
for serialized_field_value,foreign_reference in foreign_removals.items():
self._backend.set_remove(serialized_field_value, foreign_reference)
finally:
self.unlock(extra_keys=extra_keys)
@staticmethod
def get_model_name(model_or_str) -> str:
    """Return the model name for either a name string or a model class.

    A string is returned unchanged; a class whose metaclass is named ``MetaModel`` and that derives from
    Model yields its ``__name__``. Anything else raises a bare ``Exception``.
    """
    if not isinstance(model_or_str, str):
        has_meta_model = type(model_or_str).__name__ == 'MetaModel'
        if has_meta_model and issubclass(model_or_str, Model):
            return model_or_str.__name__
        raise Exception()
    return model_or_str
# NOTE(review): the `def` line that opens this parameter list is not visible -- confirm against upstream.
# Returns the references pointing at this instance as a set of tuples obtained by splitting each stored
# reference string at its last ':' (reference strings are written elsewhere in this file as
# '<instance key>:<field name>'), optionally restricted to some models.
# `filter_by_models` may be None (no filtering), a model class, a model name string, or a list/set/tuple of
# those; anything else raises AttributeError.
self, filter_by_models : Optional[Union[type, List[type], Set[type], Tuple[type]]] = None
) -> Set[Tuple[str, str]]:
try:
self.lock()
# NOTE(review): `{}` is an empty dict, although the annotated return type is a set -- verify.
if not self._backend.exists(self._references_key): return {}
references = self._backend.set_get_all(self._references_key)
# Normalise `filter_by_models` into a set of model names (or leave it as None).
try:
if filter_by_models is None:
pass
elif isinstance(filter_by_models, str):
filter_by_models = {filter_by_models}
elif isinstance(filter_by_models, (list, set, tuple)):
filter_by_models = {Model.get_model_name(model_or_str) for model_or_str in filter_by_models}
elif (type(filter_by_models).__name__ == 'MetaModel') and issubclass(filter_by_models, Model):
filter_by_models = {Model.get_model_name(filter_by_models)}
else:
raise Exception()
except Exception as e:
# Any failure while normalising is reported as an AttributeError chained to the original cause.
msg = 'filter_by_models({:s}) unsupported. Expected a type or a list/set of types. Optionally, keep '\
'it as None to retrieve all the references pointing to this instance.'
raise AttributeError(msg.format(str(filter_by_models))) from e
if filter_by_models:
# The model name of a reference is the text before its first '['.
references = filter(lambda instance_key: instance_key.split('[', 1)[0] in filter_by_models, references)
return {tuple(reference.rsplit(':', 1)) for reference in references}
finally:
self.unlock()
@classmethod
def get_primary_keys(cls, database : Database):
    """Return what the backend stores under this model's instances key.

    The instances key is locked for the duration of the read and always unlocked afterwards, even when the
    read itself raises.

    Parameters:
        database: Database whose ``backend`` is queried.

    Returns:
        The result of ``backend.set_get_all`` for the instances key (instance keys are added to that entry
        elsewhere in this file).

    Raises:
        MutexException: if the lock cannot be acquired, or cannot be released after a successful read.
    """
    backend = database.backend
    key_model_instances = cls.get_backend_key_instances()
    key_model_instances_lock = cls.get_backend_key_lock(key_model_instances)

    acquired,owner_key = backend.lock(key_model_instances_lock)
    if not acquired:
        raise MutexException('Unable to lock keys {:s}'.format(
            str(key_model_instances_lock)))

    try:
        instance_keys = backend.set_get_all(key_model_instances)
    finally:
        # Release in `finally` so that a failing read cannot leave the instances key locked.
        released = backend.unlock(key_model_instances_lock, owner_key)

    if not released:
        raise MutexException('Unable to unlock keys {:s} using owner_key {:s}'.format(
            str(key_model_instances_lock), str(owner_key)))
    return instance_keys
def dump_id(self) -> Dict:
    """Placeholder that always raises NotImplementedError; presumably overridden by concrete models."""
    raise NotImplementedError()
def dump(self) -> Dict:
    """Placeholder that always raises NotImplementedError; presumably overridden by concrete models."""
    raise NotImplementedError()
def __repr__(self) -> str:
    """Render the instance as ``ClassName(field=value, ...)``, tagging the primary-key field with ``(PK)``."""
    pk_field_name = self._pk_field_name # pylint: disable=no-member
    rendered_fields = []
    for name in self._field_names_list: # pylint: disable=no-member
        pk_marker = '(PK)' if name == pk_field_name else ''
        rendered_fields.append('{:s}={:s}{:s}'.format(name, repr(getattr(self, name)), pk_marker))
    return '{:s}({:s})'.format(self._class_name, ', '.join(rendered_fields))