Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixbug threading local in coroutine #120

Merged
merged 5 commits into from
Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions aredis/lock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import threading
import time as mod_time
import uuid
import warnings

import contextvars

from aredis.connection import ClusterConnection
from aredis.exceptions import LockError, WatchError
from aredis.utils import b, dummy
Expand Down Expand Up @@ -77,8 +78,7 @@ def __init__(self, redis, name, timeout=None, sleep=0.1,
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.local = threading.local() if self.thread_local else dummy()
self.local.token = None
self.local = contextvars.ContextVar('token', default=None) if self.thread_local else dummy()
Copy link

@eirikur-grid eirikur-grid Aug 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation for contextvars has the following disclaimer:

Important: Context Variables should be created at the top module level and never in closures. Context objects hold strong references to context variables which prevents context variables from being properly garbage collected.

Declaring the context variable in the __init__ method may result in a memory leak if my understanding is correct. I propose to declare it outside of the class definition. For example:

token_var = contextvars.ContextVar('token', default=None)


class Lock:
  def __init__(...):
  self.local = token_var if self.thread_local else dummy()
  ...

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, moving the contextvar definition to the module level will cause tokens to bleed between lock instances created by the same co-routine. That's a bad idea. Sorry, you can ignore my proposal.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, moving the contextvar definition to the module level will cause tokens to bleed between lock instances created by the same co-routine. That's a bad idea. Sorry, you can ignore my proposal.

Sorry, github didn't send me a msg and i saw your code review so late.

if self.timeout and self.sleep > self.timeout:
raise LockError("'sleep' must be less than 'timeout'")

Expand Down Expand Up @@ -113,7 +113,7 @@ async def acquire(self, blocking=None, blocking_timeout=None):
stop_trying_at = mod_time.time() + blocking_timeout
while True:
if await self.do_acquire(token):
self.local.token = token
self.local.set(token)
return True
if not blocking:
return False
Expand All @@ -131,10 +131,10 @@ async def do_acquire(self, token):

async def release(self):
"Releases the already acquired lock"
expected_token = self.local.token
expected_token = self.local.get()
if expected_token is None:
raise LockError("Cannot release an unlocked lock")
self.local.token = None
self.local.set(None)
await self.do_release(expected_token)

async def do_release(self, expected_token):
Expand All @@ -155,7 +155,7 @@ async def extend(self, additional_time):
``additional_time`` can be specified as an integer or a float, both
representing the number of seconds to add.
"""
if self.local.token is None:
if self.local.get() is None:
raise LockError("Cannot extend an unlocked lock")
if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout")
Expand All @@ -165,7 +165,7 @@ async def do_extend(self, additional_time):
pipe = await self.redis.pipeline()
await pipe.watch(self.name)
lock_value = await pipe.get(self.name)
if lock_value != self.local.token:
if lock_value != self.local.get():
raise LockError("Cannot extend a lock that's no longer owned")
expiration = await pipe.pttl(self.name)
if expiration is None or expiration < 0:
Expand Down Expand Up @@ -246,7 +246,7 @@ async def do_release(self, expected_token):
async def do_extend(self, additional_time):
additional_time = int(additional_time * 1000)
if not bool(await self.lua_extend.execute(keys=[self.name],
args=[self.local.token, additional_time],
args=[self.local.get(), additional_time],
client=self.redis)):
raise LockError("Cannot extend a lock that's no longer owned")
return True
Expand Down Expand Up @@ -337,7 +337,7 @@ async def acquire(self, blocking=None, blocking_timeout=None):
if check_finished_at > stop_trying_at:
await self.do_release(token)
return False
self.local.token = token
self.local.set(token)
# validity time is considered to be the
# initial validity time minus the time elapsed during check
await self.do_extend(lock_acquired_at - check_finished_at)
Expand Down
13 changes: 10 additions & 3 deletions aredis/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import sys
from functools import wraps

from aredis.exceptions import (RedisClusterException,
ClusterDownError)
from aredis.exceptions import (ClusterDownError, RedisClusterException)

_C_EXTENSION_SPEEDUP = False
try:
Expand Down Expand Up @@ -56,7 +55,15 @@ class dummy(object):
"""
Instances of this class can be used as an attribute container.
"""
pass

def __init__(self):
self.token = None

def set(self, value):
self.token = value

def get(self):
return self.token


# ++++++++++ response callbacks ++++++++++++++
Expand Down
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-r test_requirements.txt
hiredis
uvloop
contextvars
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,12 @@ def build_extension(self, ext):
ext_modules=[
Extension(name='aredis.speedups',
sources=['aredis/speedups.c']),
],
# The good news is that the standard library always
# takes the precedence over site packages,
# so even if a local contextvars module is installed,
# the one from the standard library will be used.
install_requires=[
'contextvars;python_version<"3.7"'
]
)
1 change: 1 addition & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mock
pytest
pytest-asyncio
contextvars
30 changes: 26 additions & 4 deletions tests/client/test_lock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import with_statement
import pytest

import asyncio
import time

import pytest

from aredis.exceptions import LockError, ResponseError
from aredis.lock import Lock, LuaLock

Expand All @@ -18,7 +21,7 @@ async def test_lock(self, r):
await r.flushdb()
lock = self.get_lock(r, 'foo')
assert await lock.acquire(blocking=False)
assert await r.get('foo') == lock.local.token
assert await r.get('foo') == lock.local.get()
assert await r.ttl('foo') == -1
await lock.release()
assert await r.get('foo') is None
Expand Down Expand Up @@ -63,7 +66,7 @@ async def test_context_manager(self, r):
# blocking_timeout prevents a deadlock if the lock can't be acquired
# for some reason
async with self.get_lock(r, 'foo', blocking_timeout=0.2) as lock:
assert await r.get('foo') == lock.local.token
assert await r.get('foo') == lock.local.get()
assert await r.get('foo') is None

@pytest.mark.asyncio()
Expand All @@ -87,7 +90,7 @@ async def test_releasing_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockError):
await lock.release()
# even though we errored, the token is still cleared
assert lock.local.token is None
assert lock.local.get() is None

@pytest.mark.asyncio()
async def test_extend_lock(self, r):
Expand Down Expand Up @@ -133,6 +136,23 @@ async def test_extending_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockError):
await lock.extend(10)

@pytest.mark.asyncio()
async def test_concurrent_lock_acquire(self, r):
lock = self.get_lock(r, 'test', timeout=1)

async def coro(lock):
is_error_raised = False
await lock.acquire(blocking=True)
await asyncio.sleep(1.5)
try:
await lock.release()
except LockError as exc:
is_error_raised = True
return is_error_raised

results = await asyncio.gather(coro(lock), coro(lock))
assert not (results[0] and results[1])

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion fails when I run this test locally. I believe the assertion is incorrect.

Both tasks/coroutines sleep for longer than the lifetime of the lock. As a consequence, both will encounter a LockError when attempting to release the lock. Hence, the correct assertion would be assert (result[0] and result[1])

Copy link
Owner Author

@NoneGG NoneGG Dec 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eirikur-grid

This assertion fails when I run this test locally. I believe the assertion is incorrect.

Both tasks/coroutines sleep for longer than the lifetime of the lock. As a consequence, both will encounter a LockError when attempting to release the lock. Hence, the correct assertion would be assert (result[0] and result[1])

I think you are right, the test result is random which is not appropriate



class TestLuaLock(TestLock):
lock_class = LuaLock
Expand Down Expand Up @@ -170,6 +190,7 @@ async def test_lua_compatible_server(self, r, monkeypatch):
@classmethod
def mock_register(cls, redis):
return

monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = r.lock('foo')
Expand All @@ -183,6 +204,7 @@ async def test_lua_unavailable(self, r, monkeypatch):
@classmethod
def mock_register(cls, redis):
raise ResponseError()

monkeypatch.setattr(LuaLock, 'register_scripts', mock_register)
try:
lock = r.lock('foo')
Expand Down
6 changes: 3 additions & 3 deletions tests/cluster/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def test_lock(self, r):
await r.flushdb()
lock = self.get_lock(r, 'foo', timeout=3)
assert await lock.acquire(blocking=False)
assert await r.get('foo') == lock.local.token
assert await r.get('foo') == lock.local.get()
assert await r.ttl('foo') == 3
await lock.release()
assert await r.get('foo') is None
Expand Down Expand Up @@ -61,7 +61,7 @@ async def test_context_manager(self, r):
# blocking_timeout prevents a deadlock if the lock can't be acquired
# for some reason
async with self.get_lock(r, 'foo', timeout=3, blocking_timeout=0.2) as lock:
assert await r.get('foo') == lock.local.token
assert await r.get('foo') == lock.local.get()
assert await r.get('foo') is None

@pytest.mark.asyncio()
Expand All @@ -85,7 +85,7 @@ async def test_releasing_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockError):
await lock.release()
# even though we errored, the token is still cleared
assert lock.local.token is None
assert lock.local.get() is None

@pytest.mark.asyncio()
async def test_extend_lock(self, r):
Expand Down