Skip to content

Commit

Permalink
Fixbug threading local in coroutine (#120)
Browse files Browse the repository at this point in the history
* New: use contextvars instead of threading local in case of the safety of thread local mechanism being broken by coroutine

* New: related test cases of redis lock in concurrency scene

* Update: change test/dev requirements

* New: change cluster tests according to lock change

* Update: change interval & timeout of test case
  • Loading branch information
NoneGG authored Aug 21, 2019
1 parent 9eabfb0 commit 613dc13
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 20 deletions.
20 changes: 10 additions & 10 deletions aredis/lock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import threading
import time as mod_time
import uuid
import warnings

import contextvars

from aredis.connection import ClusterConnection
from aredis.exceptions import LockError, WatchError
from aredis.utils import b, dummy
Expand Down Expand Up @@ -77,8 +78,7 @@ def __init__(self, redis, name, timeout=None, sleep=0.1,
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.local = threading.local() if self.thread_local else dummy()
self.local.token = None
self.local = contextvars.ContextVar('token', default=None) if self.thread_local else dummy()
if self.timeout and self.sleep > self.timeout:
raise LockError("'sleep' must be less than 'timeout'")

Expand Down Expand Up @@ -113,7 +113,7 @@ async def acquire(self, blocking=None, blocking_timeout=None):
stop_trying_at = mod_time.time() + blocking_timeout
while True:
if await self.do_acquire(token):
self.local.token = token
self.local.set(token)
return True
if not blocking:
return False
Expand All @@ -131,10 +131,10 @@ async def do_acquire(self, token):

async def release(self):
"Releases the already acquired lock"
expected_token = self.local.token
expected_token = self.local.get()
if expected_token is None:
raise LockError("Cannot release an unlocked lock")
self.local.token = None
self.local.set(None)
await self.do_release(expected_token)

async def do_release(self, expected_token):
Expand All @@ -155,7 +155,7 @@ async def extend(self, additional_time):
``additional_time`` can be specified as an integer or a float, both
representing the number of seconds to add.
"""
if self.local.token is None:
if self.local.get() is None:
raise LockError("Cannot extend an unlocked lock")
if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout")
Expand All @@ -165,7 +165,7 @@ async def do_extend(self, additional_time):
pipe = await self.redis.pipeline()
await pipe.watch(self.name)
lock_value = await pipe.get(self.name)
if lock_value != self.local.token:
if lock_value != self.local.get():
raise LockError("Cannot extend a lock that's no longer owned")
expiration = await pipe.pttl(self.name)
if expiration is None or expiration < 0:
Expand Down Expand Up @@ -246,7 +246,7 @@ async def do_release(self, expected_token):
async def do_extend(self, additional_time):
additional_time = int(additional_time * 1000)
if not bool(await self.lua_extend.execute(keys=[self.name],
args=[self.local.token, additional_time],
args=[self.local.get(), additional_time],
client=self.redis)):
raise LockError("Cannot extend a lock that's no longer owned")
return True
Expand Down Expand Up @@ -337,7 +337,7 @@ async def acquire(self, blocking=None, blocking_timeout=None):
if check_finished_at > stop_trying_at:
await self.do_release(token)
return False
self.local.token = token
self.local.set(token)
# validity time is considered to be the
# initial validity time minus the time elapsed during check
await self.do_extend(lock_acquired_at - check_finished_at)
Expand Down
13 changes: 10 additions & 3 deletions aredis/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import sys
from functools import wraps

from aredis.exceptions import (RedisClusterException,
ClusterDownError)
from aredis.exceptions import (ClusterDownError, RedisClusterException)

_C_EXTENSION_SPEEDUP = False
try:
Expand Down Expand Up @@ -56,7 +55,15 @@ class dummy(object):
"""
Instances of this class can be used as an attribute container.
"""
pass

def __init__(self):
self.token = None

def set(self, value):
self.token = value

def get(self):
return self.token


# ++++++++++ response callbacks ++++++++++++++
Expand Down
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-r test_requirements.txt
hiredis
uvloop
contextvars
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,12 @@ def build_extension(self, ext):
ext_modules=[
Extension(name='aredis.speedups',
sources=['aredis/speedups.c']),
],
# The good news is that the standard library always
# takes the precedence over site packages,
# so even if a local contextvars module is installed,
# the one from the standard library will be used.
install_requires=[
'contextvars;python_version<"3.7"'
]
)
1 change: 1 addition & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mock
pytest
pytest-asyncio
contextvars
30 changes: 26 additions & 4 deletions tests/client/test_lock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import with_statement
import pytest

import asyncio
import time

import pytest

from aredis.exceptions import LockError, ResponseError
from aredis.lock import Lock, LuaLock

Expand All @@ -18,7 +21,7 @@ async def test_lock(self, r):
await r.flushdb()
lock = self.get_lock(r, 'foo')
assert await lock.acquire(blocking=False)
assert await r.get('foo') == lock.local.token
assert await r.get('foo') == lock.local.get()
assert await r.ttl('foo') == -1
await lock.release()
assert await r.get('foo') is None
Expand Down Expand Up @@ -63,7 +66,7 @@ async def test_context_manager(self, r):
# blocking_timeout prevents a deadlock if the lock can't be acquired
# for some reason
async with self.get_lock(r, 'foo', blocking_timeout=0.2) as lock:
assert await r.get('foo') == lock.local.token
assert await r.get('foo') == lock.local.get()
assert await r.get('foo') is None

@pytest.mark.asyncio()
Expand All @@ -87,7 +90,7 @@ async def test_releasing_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockError):
await lock.release()
# even though we errored, the token is still cleared
assert lock.local.token is None
assert lock.local.get() is None

@pytest.mark.asyncio()
async def test_extend_lock(self, r):
Expand Down Expand Up @@ -133,6 +136,23 @@ async def test_extending_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockError):
await lock.extend(10)

@pytest.mark.asyncio()
async def test_concurrent_lock_acquire(self, r):
lock = self.get_lock(r, 'test', timeout=1)

async def coro(lock):
is_error_raised = False
await lock.acquire(blocking=True)
await asyncio.sleep(1.5)
try:
await lock.release()
except LockError as exc:
is_error_raised = True
return is_error_raised

results = await asyncio.gather(coro(lock), coro(lock))
assert not (results[0] and results[1])


class TestLuaLock(TestLock):
lock_class = LuaLock
Expand Down Expand Up @@ -170,6 +190,7 @@ async def test_lua_compatible_server(self, r, monkeypatch):
@classmethod
def mock_register(cls, redis):
return

monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = r.lock('foo')
Expand All @@ -183,6 +204,7 @@ async def test_lua_unavailable(self, r, monkeypatch):
@classmethod
def mock_register(cls, redis):
raise ResponseError()

monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = r.lock('foo')
Expand Down
6 changes: 3 additions & 3 deletions tests/cluster/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def test_lock(self, r):
await r.flushdb()
lock = self.get_lock(r, 'foo', timeout=3)
assert await lock.acquire(blocking=False)
assert await r.get('foo') == lock.local.token
assert await r.get('foo') == lock.local.get()
assert await r.ttl('foo') == 3
await lock.release()
assert await r.get('foo') is None
Expand Down Expand Up @@ -61,7 +61,7 @@ async def test_context_manager(self, r):
# blocking_timeout prevents a deadlock if the lock can't be acquired
# for some reason
async with self.get_lock(r, 'foo', timeout=3, blocking_timeout=0.2) as lock:
assert await r.get('foo') == lock.local.token
assert await r.get('foo') == lock.local.get()
assert await r.get('foo') is None

@pytest.mark.asyncio()
Expand All @@ -85,7 +85,7 @@ async def test_releasing_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockError):
await lock.release()
# even though we errored, the token is still cleared
assert lock.local.token is None
assert lock.local.get() is None

@pytest.mark.asyncio()
async def test_extend_lock(self, r):
Expand Down

0 comments on commit 613dc13

Please sign in to comment.