diff --git a/CHANGES b/CHANGES index 09babb706f..e83660d6ac 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Add test and fix async HiredisParser when reading during a disconnect() (#2349) * Use hiredis-py pack_command if available. * Support `.unlink()` in ClusterPipeline * Simplify synchronous SocketBuffer state management diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2c75d4fcf1..862f6f096b 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -350,13 +350,14 @@ async def _readline(self) -> bytes: class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" - __slots__ = BaseParser.__slots__ + ("_reader",) + __slots__ = BaseParser.__slots__ + ("_reader", "_connected") def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) self._reader: Optional[hiredis.Reader] = None + self._connected: bool = False def on_connect(self, connection: "Connection"): self._stream = connection._reader @@ -369,13 +370,13 @@ def on_connect(self, connection: "Connection"): kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) + self._connected = True def on_disconnect(self): - self._stream = None - self._reader = None + self._connected = False async def can_read_destructive(self): - if not self._stream or not self._reader: + if not self._connected: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._reader.gets(): return True @@ -397,8 +398,10 @@ async def read_from_socket(self): async def read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, List[EncodableT]]: - if not self._stream or not self._reader: - self.on_disconnect() + # If `on_disconnect()` has been called, prohibit any more reads + # even if they could happen because data might be present. + # We still allow reads in progress to finish + if not self._connected: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None response = self._reader.gets() diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 8e4fdac309..1851ca9a76 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -10,12 +10,14 @@ from redis.asyncio.connection import ( BaseParser, Connection, + HiredisParser, PythonParser, UnixDomainSocketConnection, ) from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt from .compat import mock @@ -191,3 +193,83 @@ async def test_connection_parse_response_resume(r: redis.Redis): pytest.fail("didn't receive a response") assert response assert i > 0 + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize( + "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] +) +async def test_connection_disconect_race(parser_class): + """ + This test reproduces the case in issue #2349 + where a connection is closed while the parser is reading to feed the + internal buffer.The stream `read()` will succeed, but when it returns, + another task has already called `disconnect()` and is waiting for + close to finish. When we attempts to feed the buffer, we will fail + since the buffer is no longer there. + + This test verifies that a read in progress can finish even + if the `disconnect()` method is called. + """ + if parser_class == PythonParser: + pytest.xfail("doesn't work yet with PythonParser") + if parser_class == HiredisParser and not HIREDIS_AVAILABLE: + pytest.skip("Hiredis not available") + + args = {} + args["parser_class"] = parser_class + + conn = Connection(**args) + + cond = asyncio.Condition() + # 0 == initial + # 1 == reader is reading + # 2 == closer has closed and is waiting for close to finish + state = 0 + + # Mock read function, which wait for a close to happen before returning + # Can either be invoked as two `read()` calls (HiredisParser) + # or as a `readline()` followed by `readexact()` (PythonParser) + chunks = [b"$13\r\n", b"Hello, World!\r\n"] + + async def read(_=None): + nonlocal state + async with cond: + if state == 0: + state = 1 # we are reading + cond.notify() + # wait until the closing task has done + await cond.wait_for(lambda: state == 2) + return chunks.pop(0) + + # function closes the connection while reader is still blocked reading + async def do_close(): + nonlocal state + async with cond: + await cond.wait_for(lambda: state == 1) + state = 2 + cond.notify() + await conn.disconnect() + + async def do_read(): + return await conn.read_response() + + reader = mock.AsyncMock() + writer = mock.AsyncMock() + writer.transport = mock.Mock() + writer.transport.get_extra_info.side_effect = None + + # for HiredisParser + reader.read.side_effect = read + # for PythonParser + reader.readline.side_effect = read + reader.readexactly.side_effect = read + + async def open_connection(*args, **kwargs): + return reader, writer + + with patch.object(asyncio, "open_connection", open_connection): + await conn.connect() + + vals = await asyncio.gather(do_read(), do_close()) + assert vals == [b"Hello, World!", None]