Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
RedisDatabaseEngine.py 4.18 KiB
import os
import uuid
from typing import Dict, List, Optional, Set, Tuple

from redis.client import Redis

from .._DatabaseEngine import _DatabaseEngine
from .Mutex import Mutex

# Mutex key used by RedisDatabaseEngine.lock()/unlock() to guard the whole database.
KEY_ENTIRE_DATABASE_LOCK = 'everything'

def get_setting(settings, name):
    """Take setting ``name`` out of ``settings``, else read the same-named env variable.

    The entry is popped (removed) from ``settings`` when present, even if the
    environment also defines it. An explicit ``None`` in ``settings`` counts as
    "not specified" and does not fall back to the environment.

    Raises:
        Exception: when no non-None value is found.
    """
    env_default = os.environ.get(name, None)
    resolved = settings.pop(name, env_default)
    if resolved is not None:
        return resolved
    raise Exception('Setting({}) not specified in environment or configuration'.format(name))

class RedisDatabaseEngine(_DatabaseEngine):
    """Database engine storing its data in a Redis database.

    Generic containers map onto native Redis types: sets -> Redis sets,
    lists -> Redis lists, dicts -> Redis hashes, plain values -> Redis strings.
    Values read back from Redis (raw bytes) are decoded as UTF-8 into ``str``.
    """

    def __init__(self, **settings) -> None:
        """Connect to Redis using ``REDISDB_SERVICE_HOST``/``_PORT``/``REDISDB_DATABASE_ID``.

        Each setting is popped from ``settings`` when present, otherwise read from
        the environment; a missing setting raises ``Exception`` (via ``get_setting``).
        """
        host = get_setting(settings, 'REDISDB_SERVICE_HOST')
        port = get_setting(settings, 'REDISDB_SERVICE_PORT')
        dbid = get_setting(settings, 'REDISDB_DATABASE_ID')
        self._client = Redis.from_url('redis://{host}:{port}/{dbid}'.format(host=host, port=port, dbid=dbid))
        self._mutex = Mutex(self._client)

    @staticmethod
    def _to_str(value):
        """Decode ``bytes`` returned by Redis as UTF-8; pass any other value (e.g. None) through."""
        return value.decode('UTF-8') if isinstance(value, bytes) else value

    def lock(self) -> Tuple[bool, str]:
        """Acquire the database-wide mutex under a freshly generated owner key.

        Calls ``Mutex.acquire`` with ``blocking=True`` and returns its result
        unchanged. NOTE(review): the annotation suggests ``(acquired, owner_key)``;
        confirm against ``Mutex``. The owner key must be passed to ``unlock``.
        """
        owner_key = str(uuid.uuid4())
        return self._mutex.acquire(KEY_ENTIRE_DATABASE_LOCK, owner_key=owner_key, blocking=True)

    def unlock(self, owner_key : str) -> bool:
        """Release the database-wide mutex held by ``owner_key``; returns ``Mutex.release``'s result."""
        return self._mutex.release(KEY_ENTIRE_DATABASE_LOCK, owner_key)

    def keys(self) -> list:
        """Return the names of all keys in the selected Redis database, decoded to ``str``.

        NOTE: uses the Redis KEYS command, which walks the entire keyspace in one
        blocking call; fine for small databases, costly for large ones.
        """
        return [raw_key.decode('UTF-8') for raw_key in self._client.keys()]

    def exists(self, key_name : str) -> bool:
        """Return True if ``key_name`` exists."""
        return self._client.exists(key_name) == 1

    def delete(self, key_name : str) -> bool:
        """Delete ``key_name``; return True if exactly one key was removed."""
        return self._client.delete(key_name) == 1

    def set_has(self, key_name : str, item : str) -> bool:
        """Return True if ``item`` is a member of the set stored at ``key_name``."""
        return self._client.sismember(key_name, item) == 1

    def set_add(self, key_name : str, item : str) -> None:
        """Add ``item`` to the set stored at ``key_name``."""
        self._client.sadd(key_name, item)

    def set_remove(self, key_name : str, item : str) -> None:
        """Remove ``item`` from the set stored at ``key_name``."""
        self._client.srem(key_name, item)

    def list_push_last(self, key_name : str, item : str) -> None:
        """Append ``item`` to the end of the list stored at ``key_name``."""
        self._client.rpush(key_name, item)

    def list_get_all(self, key_name : str) -> List[str]:
        """Return every element of the list stored at ``key_name``, in order."""
        return [member.decode('UTF-8') for member in self._client.lrange(key_name, 0, -1)]

    def list_remove_first_occurrence(self, key_name : str, item: str) -> None:
        """Remove the first occurrence (searching from the head) of ``item`` from the list."""
        self._client.lrem(key_name, 1, item)

    def dict_get(self, key_name : str, fields : Optional[List[str]] = None) -> Dict[str, str]:
        """Read fields of the hash stored at ``key_name``.

        Args:
            key_name: hash key.
            fields: field names to read; ``None`` or empty reads every field.

        Returns:
            Mapping of field name to decoded value. When specific fields are
            requested, fields absent from the hash map to ``None``.
        """
        field_list = list(fields) if fields is not None else []
        if not field_list:
            pairs = self._client.hgetall(key_name).items()
        else:
            # Materialized once so the field order matches hmget's positional results.
            pairs = zip(field_list, self._client.hmget(key_name, field_list))
        return {self._to_str(field): self._to_str(value) for field, value in pairs}

    def dict_update(
        self, key_name : str, update_fields : Optional[Dict[str, str]] = None,
        remove_fields : Optional[Set[str]] = None) -> None:
        """Remove ``remove_fields`` and then set ``update_fields`` on the hash at ``key_name``.

        Removal runs first, so a field listed in both ends up set to its new value.
        The two steps are separate Redis commands and are not applied atomically.
        """
        if remove_fields:
            self._client.hdel(key_name, *remove_fields)

        if update_fields:
            self._client.hset(key_name, mapping=update_fields)

    def dict_delete(self, key_name : str, fields : Optional[List[str]] = None) -> None:
        """Delete selected fields of the hash at ``key_name``, or the whole key.

        Args:
            key_name: hash key.
            fields: field names to delete; ``None`` or empty deletes the entire key.
        """
        field_set = set(fields) if fields is not None else set()
        if not field_set:
            self._client.delete(key_name)
        else:
            # hdel takes field names as varargs; passing the set itself as a single
            # argument is rejected by the client instead of deleting the fields.
            self._client.hdel(key_name, *field_set)

    def dump(self) -> List[Tuple[str, str, object]]:
        """Return ``(key_name, type_name, content)`` for every key in the database.

        ``type_name`` is one of 'dict', 'list', 'set', 'str', or None for other
        Redis types; the content of such keys is reported as 'UNSUPPORTED_TYPE'.
        """
        type_names = {
            'hash'  : 'dict',
            'list'  : 'list',
            'set'   : 'set',
            'string': 'str',
        }
        readers = {
            'dict': lambda key: {k.decode('UTF-8'):v.decode('UTF-8') for k,v in self._client.hgetall(key).items()},
            'list': lambda key: [m.decode('UTF-8') for m in self._client.lrange(key, 0, -1)],
            'set' : lambda key: {m.decode('UTF-8') for m in self._client.smembers(key)},
            # _to_str tolerates None if the key expired after it was listed.
            'str' : lambda key: self._to_str(self._client.get(key)),
        }

        def unsupported(key):
            return 'UNSUPPORTED_TYPE'

        entries = []
        for key_name in self.keys():
            redis_type = self._to_str(self._client.type(key_name))
            key_type = type_names.get(redis_type)
            read_content = readers.get(key_type, unsupported)
            entries.append((key_name, key_type, read_content(key_name)))
        return entries