Skip to content
Snippets Groups Projects
RedisDatabaseEngine.py 4.07 KiB
Newer Older
Lluis Gifre Renom's avatar
Lluis Gifre Renom committed
import uuid
from typing import Dict, List, Optional, Set, Tuple

from redis.client import Redis

from common.Settings import get_setting
from common.database.engines._DatabaseEngine import _DatabaseEngine
from common.database.engines.redis.Mutex import Mutex
Lluis Gifre Renom's avatar
Lluis Gifre Renom committed

KEY_ENTIRE_DATABASE_LOCK = 'everything'

class RedisDatabaseEngine(_DatabaseEngine):
    """Database engine that stores keys, sets, lists and hashes ("dicts") in Redis.

    The client is created without response decoding, so this class decodes the
    bytes returned by Redis as UTF-8 before handing them back to callers.
    """

    # Maps the (decoded) reply of the Redis TYPE command to the type names used by dump().
    # Any Redis type not listed here (e.g. zset, stream, or 'none' for a vanished key) maps to None.
    _REDIS_TYPE_TO_KEY_TYPE = {
        'hash'  : 'dict',
        'list'  : 'list',
        'set'   : 'set',
        'string': 'str',
    }

    def __init__(self, **settings) -> None:
        """Connect to Redis.

        Host, port and database id are resolved with ``get_setting`` for
        'REDIS_SERVICE_HOST', 'REDIS_SERVICE_PORT' and 'REDIS_DATABASE_ID',
        forwarding ``settings``. A ``Mutex`` bound to the same client is created
        for database-wide locking.
        """
        host = get_setting('REDIS_SERVICE_HOST', settings=settings)
        port = get_setting('REDIS_SERVICE_PORT', settings=settings)
        dbid = get_setting('REDIS_DATABASE_ID',  settings=settings)
        self._client = Redis.from_url('redis://{host}:{port}/{dbid}'.format(host=host, port=port, dbid=dbid))
        self._mutex = Mutex(self._client)

    @staticmethod
    def _decode_if_bytes(value):
        """Return ``value`` decoded as UTF-8 if it is ``bytes``; otherwise return it unchanged."""
        return value.decode('UTF-8') if isinstance(value, bytes) else value

    def lock(self) -> Tuple[bool, str]:
        """Acquire the database-wide lock, blocking until it is obtained.

        A fresh random owner key (UUID4) is generated for this acquisition and the
        result of ``Mutex.acquire`` is returned unchanged.
        NOTE(review): presumably ``(acquired, owner_key)``; the owner key must be
        passed back to ``unlock`` -- confirm against Mutex.acquire.
        """
        owner_key = str(uuid.uuid4())
        return self._mutex.acquire(KEY_ENTIRE_DATABASE_LOCK, owner_key=owner_key, blocking=True)

    def unlock(self, owner_key : str) -> bool:
        """Release the database-wide lock held by ``owner_key``; returns ``Mutex.release``'s result."""
        return self._mutex.release(KEY_ENTIRE_DATABASE_LOCK, owner_key)

    def keys(self) -> List[str]:
        """Return the names of all keys in the database, decoded as UTF-8.

        NOTE: uses the Redis KEYS command, which scans the whole keyspace in one call.
        """
        return [k.decode('UTF-8') for k in self._client.keys()]

    def exists(self, key_name : str) -> bool:
        """Return True if ``key_name`` exists."""
        return self._client.exists(key_name) == 1

    def delete(self, key_name : str) -> bool:
        """Delete ``key_name``; return True if exactly one key was removed."""
        return self._client.delete(key_name) == 1

    def set_has(self, key_name : str, item : str) -> bool:
        """Return True if ``item`` is a member of the set stored at ``key_name``."""
        return self._client.sismember(key_name, item) == 1

    def set_add(self, key_name : str, item : str) -> None:
        """Add ``item`` to the set stored at ``key_name``."""
        self._client.sadd(key_name, item)

    def set_remove(self, key_name : str, item : str) -> None:
        """Remove ``item`` from the set stored at ``key_name``."""
        self._client.srem(key_name, item)

    def list_push_last(self, key_name : str, item : str) -> None:
        """Append ``item`` to the end of the list stored at ``key_name``."""
        self._client.rpush(key_name, item)

    def list_get_all(self, key_name : str) -> List[str]:
        """Return all elements of the list stored at ``key_name``, decoded as UTF-8."""
        return [m.decode('UTF-8') for m in self._client.lrange(key_name, 0, -1)]

    def list_remove_first_occurrence(self, key_name : str, item: str) -> None:
        """Remove the first element equal to ``item`` from the list stored at ``key_name``."""
        self._client.lrem(key_name, 1, item)

    def dict_get(self, key_name : str, fields : Optional[List[str]] = None) -> Dict[str, str]:
        """Read fields of the hash stored at ``key_name``.

        Args:
            key_name: name of the hash.
            fields: field names to read; when empty or None, all fields are returned.

        Returns:
            Mapping of field name to value, both decoded as UTF-8. When ``fields`` is
            given, every requested field appears in the result; fields missing from
            the hash map to None (that is what HMGET returns for them).
        """
        if not fields:
            keys_values = self._client.hgetall(key_name).items()
        else:
            # Materialize once: the same sequence is sent to HMGET and zipped with its reply.
            fields = list(fields)
            keys_values = zip(fields, self._client.hmget(key_name, fields))

        return {
            self._decode_if_bytes(key): self._decode_if_bytes(value)
            for key, value in keys_values
        }

    def dict_update(
        self, key_name : str, update_fields : Optional[Dict[str, str]] = None,
        remove_fields : Optional[Set[str]] = None) -> None:
        """Remove and/or set fields of the hash stored at ``key_name``.

        Removals are applied before updates, so a field present in both ends up
        holding its value from ``update_fields``. The two steps are issued as
        separate commands (not a single transaction).
        """
        if remove_fields:
            self._client.hdel(key_name, *remove_fields)

        if update_fields:
            self._client.hset(key_name, mapping=update_fields)

    def dict_delete(self, key_name : str, fields : Optional[List[str]] = None) -> None:
        """Delete the whole hash at ``key_name``, or only the given ``fields`` of it.

        When ``fields`` is empty or None the entire key is deleted.
        """
        if not fields:
            self._client.delete(key_name)
        else:
            # HDEL takes each field name as a separate argument; a set object cannot be
            # encoded as a single argument, so unpack the de-duplicated field names.
            self._client.hdel(key_name, *set(fields))

    def _dump_content(self, key_name : str, key_type : Optional[str]) -> object:
        """Read and decode the value of ``key_name`` according to its dump() type name."""
        client = self._client
        if key_type == 'dict':
            return {k.decode('UTF-8'): v.decode('UTF-8') for k, v in client.hgetall(key_name).items()}
        if key_type == 'list':
            return [m.decode('UTF-8') for m in client.lrange(key_name, 0, -1)]
        if key_type == 'set':
            return {m.decode('UTF-8') for m in client.smembers(key_name)}
        if key_type == 'str':
            value = client.get(key_name)
            # The key may have expired or been deleted between KEYS/TYPE and GET.
            return None if value is None else value.decode('UTF-8')
        return 'UNSUPPORTED_TYPE'

    def dump(self) -> List[Tuple[str, Optional[str], object]]:
        """Return a snapshot of the database as ``(key_name, key_type, content)`` tuples.

        ``key_type`` is one of 'dict', 'list', 'set', 'str', or None for any other
        Redis type. ``content`` is the decoded value (dict, list, set or str), or the
        string 'UNSUPPORTED_TYPE' when ``key_type`` is None. The snapshot is not
        atomic: keys are read one by one after listing them.
        """
        entries = []
        for raw_key in self._client.keys():
            key_name = raw_key.decode('UTF-8')
            redis_type = self._client.type(key_name)
            if redis_type is not None:
                redis_type = redis_type.decode('UTF-8')
            key_type = self._REDIS_TYPE_TO_KEY_TYPE.get(redis_type)
            entries.append((key_name, key_type, self._dump_content(key_name, key_type)))
        return entries