import random, time, uuid
from typing import Set, Union
from redis.client import Redis
# Template of the Redis key under which the lock of an entity is stored; the
# placeholder is filled with the entity key.
KEY_LOCK = "{}/lock"

# Smallest remaining wait budget (seconds) for which a blocking acquisition
# still schedules another retry; below it the acquisition gives up.
MIN_WAIT_TIME = 0.01
class Mutex:
    """Multi-key mutex stored in Redis.

    Every entity key is mapped to the Redis key ``KEY_LOCK.format(entity_key)`` whose value is the owner key of
    the party holding the lock. Several entity keys can be locked, released or refreshed together: the operation
    applies to all of them or to none of them.
    """

    def __init__(self, redis_client: 'Redis') -> None:
        """Store the Redis client and register the Lua scripts used by the mutex.

        Raises:
            AttributeError: if ``redis_client`` is not an instance of ``redis.client.Redis``.
        """
        if not isinstance(redis_client, Redis):
            raise AttributeError('redis_client must be an instance of redis.client.Redis')
        self.redis_client = redis_client
        self.script_release = None
        self.script_refresh_expire = None
        self.__register_scripts()

    def __register_scripts(self):
        """Register the Lua scripts on the Redis client and keep the resulting callables on the instance."""
        # Script mutex_release
        # Description: release a set of mutex keys, only if all of them are owned by the caller. If owner_key
        #              matches the value stored in every mutex, delete all mutexes and return 1. If some value
        #              does not match (or a mutex does not exist), do nothing and return 0.
        # Keys: set of lock keys to be released
        # Args: owner_key
        # Ret : 1 if all keys have been released, 0 otherwise (no action performed)
        # Use : released = (int(self.script_release(keys=['mutex1', 'mutex2'], args=[owner_key])) == 1)
        self.script_release = self.redis_client.register_script('\n'.join([
            "for _,key in ipairs(KEYS) do",
            "    local owner_key = redis.call('get', key)",
            "    if owner_key ~= ARGV[1] then return 0 end",
            "end",
            "for _,key in ipairs(KEYS) do",
            "    redis.call('del', key)",
            "end",
            "return 1",
        ]))

        # Script mutex_refresh_expire
        # Description: refresh the expiracy of a set of mutex keys, only if all of them are owned by the caller.
        #              If owner_key matches the value stored in every mutex, refresh the expiracy on all
        #              mutexes and return 1. If some value does not match, do nothing and return 0.
        #              The period is given in seconds (possibly fractional); it is converted to whole
        #              milliseconds and applied with PEXPIRE, because EXPIRE only accepts integer seconds.
        # Keys: set of lock keys to be refreshed
        # Args: owner_key, expiracy_seconds
        # Ret : 1 if all keys have been refreshed, 0 otherwise (no action performed)
        # Use : done = (int(self.script_refresh_expire(keys=['mutex1', 'mutex2'], args=[owner_key, seconds])) == 1)
        self.script_refresh_expire = self.redis_client.register_script('\n'.join([
            "for _,key in ipairs(KEYS) do",
            "    local owner_key = redis.call('get', key)",
            "    if owner_key ~= ARGV[1] then return 0 end",
            "end",
            "local ttl_ms = string.format('%.0f', tonumber(ARGV[2]) * 1000)",
            "for _,key in ipairs(KEYS) do",
            "    redis.call('pexpire', key, ttl_ms)",
            "end",
            "return 1",
        ]))

    @staticmethod
    def _lock_names(entity_key_or_keys):
        """Return the sorted list of Redis lock keys for a single entity key or a set of entity keys."""
        entity_keys = entity_key_or_keys if isinstance(entity_key_or_keys, set) else {str(entity_key_or_keys)}
        return sorted({KEY_LOCK.format(entity_key) for entity_key in entity_keys})

    @staticmethod
    def _to_milliseconds(seconds):
        """Convert a period in (possibly fractional) seconds to a whole number of milliseconds."""
        return int(round(seconds * 1000))

    def acquire(self, entity_key_or_keys : Union[str, Set[str]], owner_key : Union[str, None] = None,
                blocking : bool = True, timeout : Union[float, int] = 5,
                expiracy_seconds : Union[float, int, None] = None):
        """Atomically lock all the given entity keys, or none of them.

        Args:
            entity_key_or_keys: a single entity key, or a set with all the entity keys to lock together.
            owner_key: value stored in the locks to identify the owner. The same value must be supplied to
                release or refresh the locks, and can be used to check whether the lock is still owned.
                If None (or empty), a random UUID is generated.
            blocking: if True, a failed attempt is retried after random waiting periods until the ``timeout``
                budget is exhausted; if False, a single attempt is made.
            timeout: total waiting budget, in seconds, used when ``blocking`` is True.
            expiracy_seconds: optional period after which Redis removes the locks automatically if they have
                not been released. Fractional values are rounded to whole milliseconds.

        Returns:
            ``(True, owner_key)`` if the locks were acquired, ``(False, None)`` otherwise.
        """
        owner_key = owner_key or str(uuid.uuid4())
        lock_names = self._lock_names(entity_key_or_keys)
        lock_values = {lock_name: owner_key for lock_name in lock_names}

        # MSETNX sets every key only if none of them exists yet.
        acquired = (self.redis_client.msetnx(lock_values) == 1)
        if blocking and not acquired:
            remaining_wait_time = timeout
            # Each retry sleeps a random fraction of the remaining budget, so the total wait never exceeds
            # the timeout; stop once the remaining budget becomes negligible.
            while remaining_wait_time >= MIN_WAIT_TIME:
                wait_time = remaining_wait_time * random.random()
                remaining_wait_time -= wait_time
                time.sleep(wait_time)
                acquired = (self.redis_client.msetnx(lock_values) == 1)
                if acquired:
                    break
        if not acquired:
            return False, None

        if expiracy_seconds is not None:
            # NOTE: the expiracy is set in a separate round trip after the locks have been created; if this
            # step fails, the locks exist without expiracy until they are released explicitly.
            # PEXPIRE is used because EXPIRE rejects non-integer periods.
            expiracy_milliseconds = self._to_milliseconds(expiracy_seconds)
            pipeline = self.redis_client.pipeline()
            for lock_name in lock_names:
                pipeline.pexpire(lock_name, expiracy_milliseconds)
            pipeline.execute()
        return True, owner_key

    def release(self, entity_key_or_keys : Union[str, Set[str]], owner_key : str):
        """Release the given locks only if all of them are owned by ``owner_key``.

        Returns:
            True if the locks were released, False (nothing changed) otherwise.
        """
        lock_names = self._lock_names(entity_key_or_keys)
        return int(self.script_release(keys=lock_names, args=[owner_key])) == 1

    def acquired(self, entity_key : str, owner_key : str):
        """Check whether the lock of ``entity_key`` is currently held by ``owner_key``.

        Returns:
            True if the lock exists and its stored value matches ``owner_key``, False otherwise.
        """
        value = self.redis_client.get(KEY_LOCK.format(entity_key))
        if value is None:
            return False
        if isinstance(value, bytes):
            # Clients that do not decode responses return bytes; str(bytes) would yield "b'...'" and never
            # match, so compare against the encoded owner key instead.
            return value == str(owner_key).encode('utf-8')
        return str(value) == owner_key

    def get_ttl(self, entity_key : str):
        """Return the result of the Redis TTL command for the lock of ``entity_key``."""
        return self.redis_client.ttl(KEY_LOCK.format(entity_key))

    def refresh_expiracy(self, entity_key_or_keys : Union[str, Set[str]], owner_key : str,
                         expiracy_seconds : Union[float, int]):
        """Refresh the expiracy of the given locks only if all of them are owned by ``owner_key``.

        Args:
            entity_key_or_keys: a single entity key, or a set of entity keys.
            owner_key: owner key used when the locks were acquired.
            expiracy_seconds: new expiracy period in seconds; fractional values are rounded to whole
                milliseconds by the script.

        Returns:
            True if the expiracy was refreshed on all locks, False (nothing changed) otherwise.
        """
        lock_names = self._lock_names(entity_key_or_keys)
        return int(self.script_refresh_expire(keys=lock_names, args=[owner_key, expiracy_seconds])) == 1