# Skip to content
# Snippets Groups Projects
# Mutex.py 6.49 KiB
# Newer Older
import random, time, uuid
from typing import Set, Union
from redis.client import Redis

# Template of the Redis key holding the lock for one entity key: KEY_LOCK.format(entity_key).
KEY_LOCK: str = "{}/lock"
# A blocking acquire stops retrying once its remaining wait budget (seconds) drops below this.
MIN_WAIT_TIME: float = 0.01

class Mutex:
    """Redis-backed mutex able to lock one entity key or a whole set of them atomically.

    Each entity key ``k`` is stored under the Redis key ``KEY_LOCK.format(k)`` with the owner key
    as its value. Acquiring, releasing and refreshing the expiry of a group of keys are
    all-or-nothing operations, each implemented as a Lua script so it runs atomically in Redis.
    """

    def __init__(self, redis_client: Redis) -> None:
        # NOTE: a wrong client type raises AttributeError (TypeError would be more idiomatic); kept
        # as-is for callers that may already catch it.
        if not isinstance(redis_client, Redis):
            raise AttributeError('redis_client must be an instance of redis.client.Redis')
        self.redis_client = redis_client
        self.script_acquire = None
        self.script_release = None
        self.script_refresh_expire = None
        self.__register_scripts()

    def __register_scripts(self):
        # Script mutex_acquire
        #   Description: atomic script to acquire a set of mutex keys. If none of the keys exists, set all of them to
        #                owner_key (and, if given, their expiry in milliseconds) and return 1. If any key already
        #                exists, do nothing and return 0. Setting the expiry in the same script guarantees that a
        #                lock can never be left behind without its TTL.
        #   Keys: set of entity_keys to be acquired
        #   Args: owner_key[, expiracy_milliseconds]
        #   Ret : 1 if all keys have been acquired, 0 otherwise (no action performed)
        self.script_acquire = self.redis_client.register_script('\n'.join([
            "for _,key in ipairs(KEYS) do",
            "    if redis.call('exists', key) == 1 then return 0 end",
            "end",
            "for _,key in ipairs(KEYS) do",
            "    redis.call('set', key, ARGV[1])",
            "    if ARGV[2] then redis.call('pexpire', key, ARGV[2]) end",
            "end",
            "return 1",
        ]))

        # Script mutex_release
        #   Description: atomic script to release a set of mutex keys, only if all mutex keys are owned by the caller.
        #                If owner_key matches the value stored in all mutexes, remove all mutexes and return 1. If
        #                some key does not match (or no longer exists), do nothing and return 0.
        #   Keys: set of entity_keys to be released
        #   Args: owner_key
        #   Ret : 1 if all keys have been released, 0 otherwise (no action performed)
        self.script_release = self.redis_client.register_script('\n'.join([
            "for _,key in ipairs(KEYS) do",
            "    local owner_key = redis.call('get', key)",
            "    if owner_key ~= ARGV[1] then return 0 end",
            "end",
            "for _,key in ipairs(KEYS) do",
            "    redis.call('del', key)",
            "end",
            "return 1",
        ]))

        # Script mutex_refresh_expire
        #   Description: atomic script to refresh the expiry of a set of mutex keys, only if all of them are owned by
        #                the caller. If owner_key matches the value stored in all mutexes, set the expiry on all of
        #                them and return 1. If some key does not match, do nothing and return 0.
        #                PEXPIRE (milliseconds) is used because EXPIRE rejects non-integer seconds.
        #   Keys: set of entity_keys to be refreshed
        #   Args: owner_key, expiracy_milliseconds
        #   Ret : 1 if all keys have been refreshed, 0 otherwise (no action performed)
        self.script_refresh_expire = self.redis_client.register_script('\n'.join([
            "for _,key in ipairs(KEYS) do",
            "    local owner_key = redis.call('get', key)",
            "    if owner_key ~= ARGV[1] then return 0 end",
            "end",
            "for _,key in ipairs(KEYS) do",
            "    redis.call('pexpire', key, ARGV[2])",
            "end",
            "return 1",
        ]))

    @staticmethod
    def _lock_keys(entity_key_or_keys) -> list:
        # Normalize one entity key or a collection of them into a de-duplicated list of Redis lock keys.
        if isinstance(entity_key_or_keys, (set, frozenset, list, tuple)):
            entity_keys = set(entity_key_or_keys)
        else:
            entity_keys = {str(entity_key_or_keys)}
        return [KEY_LOCK.format(entity_key) for entity_key in entity_keys]

    @staticmethod
    def _to_milliseconds(seconds) -> int:
        # Convert an expiry in seconds (int, float or timedelta) to the integer milliseconds PEXPIRE needs.
        if hasattr(seconds, 'total_seconds'):
            seconds = seconds.total_seconds()
        milliseconds = int(round(seconds * 1000))
        # Never round a positive duration down to 0: PEXPIRE 0 would delete the key immediately.
        if seconds > 0 and milliseconds < 1:
            return 1
        return milliseconds

    def _try_acquire(self, lock_keys: list, script_args: list) -> bool:
        # Single atomic acquisition attempt: sets every lock key (and its expiry) or none of them.
        return int(self.script_acquire(keys=lock_keys, args=script_args)) == 1

    def acquire(self, entity_key_or_keys : Union[str, Set[str]], owner_key : Union[str, None] = None,
                blocking : bool = True, timeout : Union[float, int] = 5,
                expiracy_seconds : Union[float, int, None] = None):
        """Atomically acquire all given entity keys, or none of them.

        :param entity_key_or_keys: a single entity key, or a set (also frozenset/list/tuple) of entity
            keys that must all be locked together.
        :param owner_key: value stored in the mutex(es). It is required to release them or refresh their
            expiry, and lets the caller check whether it still owns them. If None, a random UUID4 string
            is generated.
        :param blocking: if True, retry with randomized waits until ``timeout`` seconds have elapsed;
            if False, try exactly once.
        :param timeout: total time budget, in seconds, for blocking retries.
        :param expiracy_seconds: optional lifetime of the mutex(es). If they are not released within that
            period, Redis deletes them automatically. Fractional seconds are supported (millisecond
            resolution). The expiry is set in the same atomic step as the lock itself.
        :return: ``(True, owner_key)`` if the mutex(es) were acquired, ``(False, None)`` otherwise.
        :raises ValueError: if an empty collection of entity keys is given.
        """
        lock_keys = self._lock_keys(entity_key_or_keys)
        if not lock_keys:
            raise ValueError('at least one entity key is required')

        owner_key = owner_key or str(uuid.uuid4())
        script_args = [owner_key]
        if expiracy_seconds is not None:
            script_args.append(self._to_milliseconds(expiracy_seconds))

        acquired = self._try_acquire(lock_keys, script_args)
        remaining_wait_time = timeout
        # Each retry sleeps a random fraction of the remaining budget, so contending clients spread out.
        while blocking and not acquired and remaining_wait_time >= MIN_WAIT_TIME:
            wait_time = remaining_wait_time * random.random()
            remaining_wait_time -= wait_time
            time.sleep(wait_time)
            acquired = self._try_acquire(lock_keys, script_args)

        if not acquired:
            return False, None
        return True, owner_key

    def release(self, entity_key_or_keys : Union[str, Set[str]], owner_key : str):
        """Release the mutex key(s), but only if all of them are owned by ``owner_key``.

        :return: True if every key was released, False if nothing was changed.
        """
        lock_keys = self._lock_keys(entity_key_or_keys)
        return int(self.script_release(keys=lock_keys, args=[owner_key])) == 1

    def acquired(self, entity_key : str, owner_key : str):
        """Return True if the mutex of ``entity_key`` currently holds ``owner_key``."""
        value = self.redis_client.get(KEY_LOCK.format(entity_key))
        if value is None:
            return False
        if isinstance(value, bytes):
            # Clients without decode_responses=True return bytes; str(b'k') is "b'k'" and would never
            # match, so compare against the UTF-8 encoded owner key instead.
            expected = owner_key.encode('utf-8') if isinstance(owner_key, str) else owner_key
            return value == expected
        return str(value) == owner_key

    def get_ttl(self, entity_key : str):
        """Return the mutex's remaining time to live, as reported by the Redis TTL command."""
        return self.redis_client.ttl(KEY_LOCK.format(entity_key))

    def refresh_expiracy(self, entity_key_or_keys : Union[str, Set[str]], owner_key : str,
                         expiracy_seconds : Union[float, int]):
        """Reset the expiry of the mutex key(s), but only if all of them are owned by ``owner_key``.

        :param expiracy_seconds: new lifetime in seconds (fractional values allowed, millisecond resolution).
        :return: True if every key was refreshed, False if nothing was changed.
        """
        lock_keys = self._lock_keys(entity_key_or_keys)
        script_args = [owner_key, self._to_milliseconds(expiracy_seconds)]
        return int(self.script_refresh_expire(keys=lock_keys, args=script_args)) == 1