From 48159dd11ecd3f303f5fd78361731adc01aee843 Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Mon, 4 Apr 2022 04:51:52 +0000 Subject: [PATCH 1/3] fix: Replace aioredis with redis.asyncio --- setup.cfg | 6 +- src/ai/backend/common/events.py | 12 +- .../common/{redis.py => redis_helper.py} | 115 +++++++++--------- src/ai/backend/common/types.py | 10 +- tests/redis/test_connect.py | 25 ++-- tests/redis/test_list.py | 41 +++---- tests/redis/test_pipeline.py | 28 ++--- tests/redis/test_pubsub.py | 36 +++--- tests/redis/test_stream.py | 43 ++++--- tests/redis/utils.py | 12 +- tests/test_events.py | 6 +- 11 files changed, 167 insertions(+), 167 deletions(-) rename src/ai/backend/common/{redis.py => redis_helper.py} (79%) diff --git a/setup.cfg b/setup.cfg index e2c912d4..d418f344 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,8 @@ install_requires = pyzmq>=22.1.0 aiohttp>=3.8.0 aiodns>=3.0 - aioredis[hiredis]~=2.0.1 + redis~=4.2.1 + hiredis~=2.0 aiotools>=1.5.5 async-timeout~=4.0.1 asyncudp>=0.4 @@ -87,8 +88,9 @@ lint = typecheck = mypy>=0.942 types-python-dateutil - types-toml + types-redis types-setuptools + types-toml dev = monitor = backend.ai-monitor-sentry>=0.2.1 diff --git a/src/ai/backend/common/events.py b/src/ai/backend/common/events.py index 40162ac2..b92bc253 100644 --- a/src/ai/backend/common/events.py +++ b/src/ai/backend/common/events.py @@ -36,7 +36,7 @@ from aiotools.taskgroup import PersistentTaskGroup import attr -from . import msgpack, redis +from . import msgpack, redis_helper from .logging import BraceStyleAdapter from .types import ( EtcdRedisConfig, @@ -650,7 +650,7 @@ def __init__( _redis_config = redis_config.copy() if service_name: _redis_config['service_name'] = service_name - self.redis_client = redis.get_redis_object(_redis_config, db=db) + self.redis_client = redis_helper.get_redis_object(_redis_config, db=db) self._log_events = log_events self._closed = False self.consumers = defaultdict(set) @@ -778,7 +778,7 @@ async def dispatch_subscribers( await asyncio.sleep(0) async def _consume_loop(self) -> None: - async with aclosing(redis.read_stream_by_group( + async with aclosing(redis_helper.read_stream_by_group( self.redis_client, self._stream_key, self._consumer_group, @@ -801,7 +801,7 @@ async def _consume_loop(self) -> None: log.exception('EventDispatcher.consume(): unexpected-error') async def _subscribe_loop(self) -> None: - async with aclosing(redis.read_stream( + async with aclosing(redis_helper.read_stream( self.redis_client, self._stream_key, )) as agen: @@ -839,7 +839,7 @@ def __init__( if service_name: _redis_config['service_name'] = service_name self._closed = False - self.redis_client = redis.get_redis_object(_redis_config, db=db) + self.redis_client = redis_helper.get_redis_object(_redis_config, db=db) self._log_events = log_events self._stream_key = stream_key @@ -863,7 +863,7 @@ async def produce_event( b'source': source.encode(), b'args': msgpack.packb(event.serialize()), } - await redis.execute( + await redis_helper.execute( self.redis_client, lambda r: r.xadd(self._stream_key, raw_event), # type: ignore # aio-libs/aioredis-py#1182 ) diff --git a/src/ai/backend/common/redis.py b/src/ai/backend/common/redis_helper.py similarity index 79% rename from src/ai/backend/common/redis.py rename to src/ai/backend/common/redis_helper.py index ee4ebc53..220c342a 100644 --- a/src/ai/backend/common/redis.py +++ b/src/ai/backend/common/redis_helper.py @@ -17,10 +17,11 @@ Union, ) -import aioredis -import aioredis.client -import aioredis.sentinel -import aioredis.exceptions +import redis.asyncio +import redis.asyncio.client +import redis.asyncio.sentinel +import redis.client +import redis.exceptions import yarl from .types import EtcdRedisConfig, RedisConnectionInfo @@ -76,7 +77,7 @@ def _parse_stream_msg_id(msg_id: bytes) -> Tuple[int, int]: async def subscribe( - channel: aioredis.client.PubSub, + channel: redis.asyncio.client.PubSub, *, reconnect_poll_interval: float = 0.3, ) -> AsyncIterator[Any]: @@ -88,7 +89,7 @@ async def _reset_chan(): channel.connection = None try: await channel.ping() - except aioredis.exceptions.ConnectionError: + except redis.exceptions.ConnectionError: pass else: assert channel.connection is not None @@ -102,24 +103,24 @@ async def _reset_chan(): if message is not None: yield message["data"] except ( - aioredis.exceptions.ConnectionError, - aioredis.sentinel.MasterNotFoundError, - aioredis.sentinel.SlaveNotFoundError, - aioredis.exceptions.ReadOnlyError, - aioredis.exceptions.ResponseError, + redis.exceptions.ConnectionError, + redis.asyncio.sentinel.MasterNotFoundError, + redis.asyncio.sentinel.SlaveNotFoundError, + redis.exceptions.ReadOnlyError, + redis.exceptions.ResponseError, ConnectionResetError, ConnectionNotAvailable, ): await asyncio.sleep(reconnect_poll_interval) await _reset_chan() continue - except aioredis.exceptions.ResponseError as e: + except redis.exceptions.ResponseError as e: if e.args[0].startswith("NOREPLICAS "): await asyncio.sleep(reconnect_poll_interval) await _reset_chan() continue raise - except (TimeoutError, asyncio.TimeoutError): + except (redis.exceptions.TimeoutError, asyncio.TimeoutError): continue except asyncio.CancelledError: raise @@ -128,7 +129,7 @@ async def _reset_chan(): async def blpop( - redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel, + redis_connector: RedisConnectionInfo | redis.asyncio.Redis | redis.asyncio.Sentinel, key: str, *, service_name: str = None, @@ -142,18 +143,18 @@ async def blpop( **_default_conn_opts, 'socket_connect_timeout': reconnect_poll_interval, } - if isinstance(redis, RedisConnectionInfo): - redis_client = redis.client - service_name = service_name or redis.service_name + if isinstance(redis_connector, RedisConnectionInfo): + redis_client = redis_connector.client + service_name = service_name or redis_connector.service_name else: - redis_client = redis + redis_client = redis_connector - if isinstance(redis_client, aioredis.sentinel.Sentinel): + if isinstance(redis_client, redis.asyncio.Sentinel): assert service_name is not None r = redis_client.master_for( service_name, - redis_class=aioredis.Redis, - connection_pool_class=aioredis.sentinel.SentinelConnectionPool, + redis_class=redis.asyncio.Redis, + connection_pool_class=redis.asyncio.SentinelConnectionPool, **_conn_opts, ) else: @@ -165,20 +166,20 @@ async def blpop( continue yield raw_msg[1] except ( - aioredis.exceptions.ConnectionError, - aioredis.sentinel.MasterNotFoundError, - aioredis.exceptions.ReadOnlyError, - aioredis.exceptions.ResponseError, + redis.exceptions.ConnectionError, + redis.asyncio.sentinel.MasterNotFoundError, + redis.exceptions.ReadOnlyError, + redis.exceptions.ResponseError, ConnectionResetError, ): await asyncio.sleep(reconnect_poll_interval) continue - except aioredis.exceptions.ResponseError as e: + except redis.exceptions.ResponseError as e: if e.args[0].startswith("NOREPLICAS "): await asyncio.sleep(reconnect_poll_interval) continue raise - except (TimeoutError, asyncio.TimeoutError): + except (redis.exceptions.TimeoutError, asyncio.TimeoutError): continue except asyncio.CancelledError: raise @@ -187,8 +188,8 @@ async def blpop( async def execute( - redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel, - func: Callable[[aioredis.Redis], Awaitable[Any]], + redis_connector: RedisConnectionInfo | redis.asyncio.Redis | redis.asyncio.Sentinel, + func: Callable[[redis.asyncio.Redis], Awaitable[Any]], *, service_name: str = None, read_only: bool = False, @@ -206,26 +207,26 @@ async def execute( **_default_conn_opts, 'socket_connect_timeout': reconnect_poll_interval, } - if isinstance(redis, RedisConnectionInfo): - redis_client = redis.client - service_name = service_name or redis.service_name + if isinstance(redis_connector, RedisConnectionInfo): + redis_client = redis_connector.client + service_name = service_name or redis_connector.service_name else: - redis_client = redis + redis_client = redis_connector - if isinstance(redis_client, aioredis.sentinel.Sentinel): + if isinstance(redis_client, redis.asyncio.Sentinel): assert service_name is not None if read_only: r = redis_client.slave_for( service_name, - redis_class=aioredis.Redis, - connection_pool_class=aioredis.sentinel.SentinelConnectionPool, + redis_class=redis.asyncio.Redis, + connection_pool_class=redis.asyncio.SentinelConnectionPool, **_conn_opts, ) else: r = redis_client.master_for( service_name, - redis_class=aioredis.Redis, - connection_pool_class=aioredis.sentinel.SentinelConnectionPool, + redis_class=redis.asyncio.Redis, + connection_pool_class=redis.asyncio.SentinelConnectionPool, **_conn_opts, ) else: @@ -238,14 +239,14 @@ async def execute( else: raise TypeError('The func must be a function or a coroutinefunction ' 'with no arguments.') - if isinstance(aw_or_pipe, aioredis.client.Pipeline): + if isinstance(aw_or_pipe, redis.asyncio.client.Pipeline): result = await aw_or_pipe.execute() elif inspect.isawaitable(aw_or_pipe): result = await aw_or_pipe else: raise TypeError('The return value must be an awaitable' 'or aioredis.commands.Pipeline object') - if isinstance(result, aioredis.client.Pipeline): + if isinstance(result, redis.asyncio.client.Pipeline): # This happens when func is an async function that returns a pipeline. result = await result.execute() if encoding: @@ -259,20 +260,20 @@ async def execute( else: return result except ( - aioredis.exceptions.ConnectionError, - aioredis.sentinel.MasterNotFoundError, - aioredis.sentinel.SlaveNotFoundError, - aioredis.exceptions.ReadOnlyError, + redis.exceptions.ConnectionError, + redis.asyncio.sentinel.MasterNotFoundError, + redis.asyncio.sentinel.SlaveNotFoundError, + redis.exceptions.ReadOnlyError, ConnectionResetError, ): await asyncio.sleep(reconnect_poll_interval) continue - except aioredis.exceptions.ResponseError as e: + except redis.exceptions.ResponseError as e: if "NOREPLICAS" in e.args[0]: await asyncio.sleep(reconnect_poll_interval) continue raise - except (TimeoutError, asyncio.TimeoutError): + except (redis.exceptions.TimeoutError, asyncio.TimeoutError): continue except asyncio.CancelledError: raise @@ -281,7 +282,7 @@ async def execute( async def execute_script( - redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel, + redis_connector: RedisConnectionInfo | redis.asyncio.Redis | redis.asyncio.Sentinel, script_id: str, script: str, keys: Sequence[str], @@ -303,20 +304,20 @@ async def execute_script( script_hash = _scripts.get(script_id, 'x') while True: try: - ret = await execute(redis, lambda r: r.evalsha( + ret = await execute(redis_connector, lambda r: r.evalsha( script_hash, len(keys), *keys, *args, )) break - except aioredis.exceptions.NoScriptError: + except redis.exceptions.NoScriptError: # Redis may have been restarted. - script_hash = await execute(redis, lambda r: r.script_load(script)) + script_hash = await execute(redis_connector, lambda r: r.script_load(script)) _scripts[script_id] = script_hash - except aioredis.exceptions.ResponseError as e: + except redis.exceptions.ResponseError as e: if 'NOSCRIPT' in e.args[0]: # Redis may have been restarted. - script_hash = await execute(redis, lambda r: r.script_load(script)) + script_hash = await execute(redis_connector, lambda r: r.script_load(script)) _scripts[script_id] = script_hash else: raise @@ -393,7 +394,7 @@ async def read_stream_by_group( autoclaim_start_id, ), ) - for msg_id, msg_data in aioredis.client.parse_stream_list(reply[1]): + for msg_id, msg_data in redis.client.parse_stream_list(reply[1]): # type: ignore messages.append((msg_id, msg_data)) if reply[0] == b'0-0': break @@ -422,7 +423,7 @@ async def read_stream_by_group( yield msg_id, msg_data except asyncio.CancelledError: raise - except aioredis.exceptions.ResponseError as e: + except redis.exceptions.ResponseError as e: if e.args[0].startswith("NOGROUP "): try: await execute( @@ -434,7 +435,7 @@ async def read_stream_by_group( mkstream=True, ), ) - except aioredis.exceptions.ResponseError as e: + except redis.exceptions.ResponseError as e: if e.args[0].startswith("BUSYGROUP "): pass else: @@ -456,7 +457,7 @@ def get_redis_object( sentinel_addresses = _sentinel_addresses assert redis_config.get('service_name') is not None - sentinel = aioredis.sentinel.Sentinel( + sentinel = redis.asyncio.Sentinel( [(str(host), port) for host, port in sentinel_addresses], password=redis_config.get('password'), db=str(db), @@ -478,6 +479,6 @@ def get_redis_object( .with_password(redis_config.get('password')) / str(db) ) return RedisConnectionInfo( - client=aioredis.Redis.from_url(str(url), **kwargs), + client=redis.asyncio.Redis.from_url(str(url), **kwargs), service_name=None, ) diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index f47c085c..a68a33ae 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -31,9 +31,9 @@ ) import uuid -import aioredis -import aioredis.client -import aioredis.sentinel +import redis.asyncio +import redis.asyncio.client +import redis.asyncio.sentinel import attr import trafaret as t import typeguard @@ -840,9 +840,9 @@ class EtcdRedisConfig(TypedDict, total=False): @attr.s(auto_attribs=True) class RedisConnectionInfo: - client: aioredis.Redis | aioredis.sentinel.Sentinel + client: redis.asyncio.Redis | redis.asyncio.Sentinel service_name: Optional[str] async def close(self) -> None: - if isinstance(self.client, aioredis.Redis): + if isinstance(self.client, redis.asyncio.Redis): await self.client.close() diff --git a/tests/redis/test_connect.py b/tests/redis/test_connect.py index bd076322..9dbab601 100644 --- a/tests/redis/test_connect.py +++ b/tests/redis/test_connect.py @@ -3,17 +3,16 @@ import asyncio from typing import TYPE_CHECKING -import aioredis -import aioredis.client -import aioredis.exceptions -import aioredis.sentinel +import redis.asyncio +import redis.asyncio.client +import redis.asyncio.sentinel import aiotools import pytest from .types import RedisClusterInfo from .utils import interrupt, with_timeout -from ai.backend.common import redis, validators as tx +from ai.backend.common import redis_helper, validators as tx if TYPE_CHECKING: from typing import Any @@ -21,7 +20,7 @@ @pytest.mark.asyncio async def test_connect(redis_container: str) -> None: - r = aioredis.from_url( + r = redis.asyncio.from_url( url='redis://localhost:9379', socket_timeout=0.5, ) @@ -32,13 +31,13 @@ async def test_connect(redis_container: str) -> None: @pytest.mark.asyncio async def test_instantiate_redisconninfo() -> None: sentinels = '127.0.0.1:26379,127.0.0.1:26380,127.0.0.1:26381' - r1 = redis.get_redis_object({ + r1 = redis_helper.get_redis_object({ 'sentinel': sentinels, 'service_name': 'mymaster', 'password': 'develove', }) - assert isinstance(r1.client, aioredis.sentinel.Sentinel) + assert isinstance(r1.client, redis.asyncio.Sentinel) for i in range(3): assert r1.client.sentinels[i].connection_pool.connection_kwargs['host'] == '127.0.0.1' @@ -46,13 +45,13 @@ async def test_instantiate_redisconninfo() -> None: assert r1.client.sentinels[i].connection_pool.connection_kwargs['db'] == 0 parsed_addresses: Any = tx.DelimiterSeperatedList(tx.HostPortPair).check_and_return(sentinels) - r2 = redis.get_redis_object({ + r2 = redis_helper.get_redis_object({ 'sentinel': parsed_addresses, 'service_name': 'mymaster', 'password': 'develove', }) - assert isinstance(r2.client, aioredis.sentinel.Sentinel) + assert isinstance(r2.client, redis.asyncio.Sentinel) for i in range(3): assert r2.client.sentinels[i].connection_pool.connection_kwargs['host'] == '127.0.0.1' @@ -77,7 +76,7 @@ async def control_interrupt() -> None: do_unpause.set() await unpaused.wait() - s = aioredis.sentinel.Sentinel( + s = redis.asyncio.Sentinel( redis_cluster.sentinel_addrs, password='develove', socket_timeout=0.5, @@ -100,13 +99,13 @@ async def control_interrupt() -> None: try: master_addr = await s.discover_master('mymaster') print("MASTER", master_addr) - except aioredis.sentinel.MasterNotFoundError: + except redis.asyncio.sentinel.MasterNotFoundError: print("MASTER (not found)") try: slave_addrs = await s.discover_slaves('mymaster') print("SLAVE", slave_addrs) slave = s.slave_for('mymaster', db=9) await slave.ping() - except aioredis.sentinel.SlaveNotFoundError: + except redis.asyncio.sentinel.SlaveNotFoundError: print("SLAVE (not found)") await asyncio.sleep(1) diff --git a/tests/redis/test_list.py b/tests/redis/test_list.py index 594cdfd3..56b25bbf 100644 --- a/tests/redis/test_list.py +++ b/tests/redis/test_list.py @@ -5,14 +5,13 @@ List, ) -import aioredis -import aioredis.client -import aioredis.exceptions -import aioredis.sentinel +import redis.asyncio +import redis.asyncio.client +import redis.exceptions import aiotools import pytest -from ai.backend.common import redis +from ai.backend.common import redis_helper from ai.backend.common.types import RedisConnectionInfo from .docker import DockerRedisNode @@ -34,7 +33,7 @@ async def test_blist(redis_container: str, disruption_method: str) -> None: async def pop(r: RedisConnectionInfo, key: str) -> None: try: async with aiotools.aclosing( - redis.blpop(r, key, reconnect_poll_interval=0.3), + redis_helper.blpop(r, key, reconnect_poll_interval=0.3), ) as agen: async for raw_msg in agen: msg = raw_msg.decode() @@ -43,10 +42,10 @@ async def pop(r: RedisConnectionInfo, key: str) -> None: pass r = RedisConnectionInfo( - aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + redis.asyncio.from_url(url='redis://localhost:9379', socket_timeout=0.5), service_name=None, ) - assert isinstance(r.client, aioredis.Redis) + assert isinstance(r.client, redis.asyncio.Redis) await r.client.delete("bl1") pop_task = asyncio.create_task(pop(r, "bl1")) @@ -68,7 +67,7 @@ async def pop(r: RedisConnectionInfo, key: str) -> None: for i in range(5): # The Redis server is dead temporarily... if disruption_method == 'stop': - with pytest.raises(aioredis.exceptions.ConnectionError): + with pytest.raises(redis.exceptions.ConnectionError): await r.client.rpush("bl1", str(5 + i)) elif disruption_method == 'pause': with pytest.raises(asyncio.TimeoutError): @@ -107,7 +106,7 @@ async def test_blist_with_retrying_rpush(redis_container: str, disruption_method async def pop(r: RedisConnectionInfo, key: str) -> None: try: async with aiotools.aclosing( - redis.blpop(r, key, reconnect_poll_interval=0.3), + redis_helper.blpop(r, key, reconnect_poll_interval=0.3), ) as agen: async for raw_msg in agen: msg = raw_msg.decode() @@ -116,10 +115,10 @@ async def pop(r: RedisConnectionInfo, key: str) -> None: pass r = RedisConnectionInfo( - aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + redis.asyncio.from_url(url='redis://localhost:9379', socket_timeout=0.5), service_name=None, ) - assert isinstance(r.client, aioredis.Redis) + assert isinstance(r.client, redis.asyncio.Redis) await r.client.delete("bl1") pop_task = asyncio.create_task(pop(r, "bl1")) @@ -134,7 +133,7 @@ async def pop(r: RedisConnectionInfo, key: str) -> None: await asyncio.sleep(0) for i in range(5): - await redis.execute(r, lambda r: r.rpush("bl1", str(i))) + await redis_helper.execute(r, lambda r: r.rpush("bl1", str(i))) await asyncio.sleep(0.1) do_pause.set() await paused.wait() @@ -145,13 +144,13 @@ async def wakeup(): wakeup_task = asyncio.create_task(wakeup()) for i in range(5): - await redis.execute(r, lambda r: r.rpush("bl1", str(5 + i))) + await redis_helper.execute(r, lambda r: r.rpush("bl1", str(5 + i))) await asyncio.sleep(0.1) await wakeup_task await unpaused.wait() for i in range(5): - await redis.execute(r, lambda r: r.rpush("bl1", str(10 + i))) + await redis_helper.execute(r, lambda r: r.rpush("bl1", str(10 + i))) await asyncio.sleep(0.1) await interrupt_task @@ -183,7 +182,7 @@ async def test_blist_cluster_sentinel( async def pop(s: RedisConnectionInfo, key: str) -> None: try: async with aiotools.aclosing( - redis.blpop( + redis_helper.blpop( s, key, reconnect_poll_interval=0.3, service_name="mymaster", @@ -196,14 +195,14 @@ async def pop(s: RedisConnectionInfo, key: str) -> None: pass s = RedisConnectionInfo( - aioredis.sentinel.Sentinel( + redis.asyncio.Sentinel( redis_cluster.sentinel_addrs, password='develove', socket_timeout=0.5, ), service_name='mymaster', ) - await redis.execute(s, lambda r: r.delete("bl1")) + await redis_helper.execute(s, lambda r: r.delete("bl1")) pop_task = asyncio.create_task(pop(s, "bl1")) interrupt_task = asyncio.create_task(interrupt( @@ -218,7 +217,7 @@ async def pop(s: RedisConnectionInfo, key: str) -> None: await asyncio.sleep(0) for i in range(5): - await redis.execute( + await redis_helper.execute( s, lambda r: r.rpush("bl1", str(i)), service_name="mymaster", @@ -233,7 +232,7 @@ async def wakeup(): wakeup_task = asyncio.create_task(wakeup()) for i in range(5): - await redis.execute( + await redis_helper.execute( s, lambda r: r.rpush("bl1", str(5 + i)), service_name="mymaster", @@ -243,7 +242,7 @@ async def wakeup(): await unpaused.wait() for i in range(5): - await redis.execute( + await redis_helper.execute( s, lambda r: r.rpush("bl1", str(10 + i)), service_name="mymaster", diff --git a/tests/redis/test_pipeline.py b/tests/redis/test_pipeline.py index 570b8132..4174a535 100644 --- a/tests/redis/test_pipeline.py +++ b/tests/redis/test_pipeline.py @@ -2,12 +2,12 @@ from unittest import mock -import aioredis -import aioredis.client -import aioredis.sentinel +import redis.asyncio +import redis.asyncio.client +import redis.asyncio.sentinel import pytest -from ai.backend.common.redis import execute +from ai.backend.common.redis_helper import execute from ai.backend.common.types import RedisConnectionInfo from .types import RedisClusterInfo @@ -17,11 +17,11 @@ @pytest.mark.asyncio async def test_pipeline_single_instance(redis_container: str) -> None: rconn = RedisConnectionInfo( - aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + redis.asyncio.from_url(url='redis://localhost:9379', socket_timeout=0.5), service_name=None, ) - def _build_pipeline(r: aioredis.Redis) -> aioredis.client.Pipeline: + def _build_pipeline(r: redis.asyncio.Redis) -> redis.asyncio.client.Pipeline: pipe = r.pipeline(transaction=False) pipe.set("xyz", "123") pipe.incr("xyz") @@ -34,7 +34,7 @@ def _build_pipeline(r: aioredis.Redis) -> aioredis.client.Pipeline: actual_value = await execute(rconn, lambda r: r.get("xyz")) assert actual_value == b"124" - async def _build_pipeline_async(r: aioredis.Redis) -> aioredis.client.Pipeline: + async def _build_pipeline_async(r: redis.asyncio.Redis) -> redis.asyncio.client.Pipeline: pipe = r.pipeline(transaction=False) pipe.set("abc", "123") pipe.incr("abc") @@ -52,19 +52,19 @@ async def _build_pipeline_async(r: aioredis.Redis) -> aioredis.client.Pipeline: @pytest.mark.asyncio async def test_pipeline_single_instance_retries(redis_container: str) -> None: rconn = RedisConnectionInfo( - aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + redis.asyncio.from_url(url='redis://localhost:9379', socket_timeout=0.5), service_name=None, ) build_count = 0 patcher = mock.patch( - 'aioredis.client.Pipeline._execute_pipeline', + 'redis.asyncio.client.Pipeline._execute_pipeline', side_effect=[ConnectionResetError, ConnectionResetError, mock.DEFAULT], ) patcher.start() - def _build_pipeline(r: aioredis.Redis) -> aioredis.client.Pipeline: + def _build_pipeline(r: redis.asyncio.Redis) -> redis.asyncio.client.Pipeline: nonlocal build_count, patcher build_count += 1 if build_count == 3: @@ -86,12 +86,12 @@ def _build_pipeline(r: aioredis.Redis) -> aioredis.client.Pipeline: build_count = 0 patcher = mock.patch( - 'aioredis.client.Pipeline._execute_pipeline', + 'redis.asyncio.client.Pipeline._execute_pipeline', side_effect=[ConnectionResetError, ConnectionResetError, mock.DEFAULT], ) patcher.start() - async def _build_pipeline_async(r: aioredis.Redis) -> aioredis.client.Pipeline: + async def _build_pipeline_async(r: redis.asyncio.Redis) -> redis.asyncio.client.Pipeline: nonlocal build_count, patcher build_count += 1 if build_count == 3: @@ -115,7 +115,7 @@ async def _build_pipeline_async(r: aioredis.Redis) -> aioredis.client.Pipeline: @pytest.mark.asyncio async def test_pipeline_sentinel_cluster(redis_cluster: RedisClusterInfo) -> None: rconn = RedisConnectionInfo( - aioredis.sentinel.Sentinel( + redis.asyncio.sentinel.Sentinel( redis_cluster.sentinel_addrs, password='develove', socket_timeout=0.5, @@ -123,7 +123,7 @@ async def test_pipeline_sentinel_cluster(redis_cluster: RedisClusterInfo) -> Non service_name='mymaster', ) - def _build_pipeline(r: aioredis.Redis) -> aioredis.client.Pipeline: + def _build_pipeline(r: redis.asyncio.Redis) -> redis.asyncio.client.Pipeline: pipe = r.pipeline(transaction=False) pipe.set("xyz", "123") pipe.incr("xyz") diff --git a/tests/redis/test_pubsub.py b/tests/redis/test_pubsub.py index 31893be0..43bb9a2a 100644 --- a/tests/redis/test_pubsub.py +++ b/tests/redis/test_pubsub.py @@ -5,16 +5,16 @@ List, ) -import aioredis -import aioredis.client -import aioredis.exceptions +import redis.asyncio +import redis.asyncio.client +import redis.exceptions import aiotools import pytest from .docker import DockerRedisNode from .utils import interrupt -from ai.backend.common import redis +from ai.backend.common import redis_helper from ai.backend.common.types import RedisConnectionInfo @@ -29,10 +29,10 @@ async def test_pubsub(redis_container: str, disruption_method: str) -> None: unpaused = asyncio.Event() received_messages: List[str] = [] - async def subscribe(pubsub: aioredis.client.PubSub) -> None: + async def subscribe(pubsub: redis.asyncio.client.PubSub) -> None: try: async with aiotools.aclosing( - redis.subscribe(pubsub, reconnect_poll_interval=0.3), + redis_helper.subscribe(pubsub, reconnect_poll_interval=0.3), ) as agen: async for raw_msg in agen: msg = raw_msg.decode() @@ -41,10 +41,10 @@ async def subscribe(pubsub: aioredis.client.PubSub) -> None: pass r = RedisConnectionInfo( - aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + redis.asyncio.from_url(url='redis://localhost:9379', socket_timeout=0.5), service_name=None, ) - assert isinstance(r.client, aioredis.Redis) + assert isinstance(r.client, redis.asyncio.Redis) await r.client.delete("ch1") pubsub = r.client.pubsub() async with pubsub: @@ -69,7 +69,7 @@ async def subscribe(pubsub: aioredis.client.PubSub) -> None: for i in range(5): # The Redis server is dead temporarily... if disruption_method == 'stop': - with pytest.raises(aioredis.exceptions.ConnectionError): + with pytest.raises(redis.exceptions.ConnectionError): await r.client.publish("ch1", str(5 + i)) elif disruption_method == 'pause': with pytest.raises(asyncio.TimeoutError): @@ -111,10 +111,10 @@ async def test_pubsub_with_retrying_pub(redis_container: str, disruption_method: unpaused = asyncio.Event() received_messages: List[str] = [] - async def subscribe(pubsub: aioredis.client.PubSub) -> None: + async def subscribe(pubsub: redis.asyncio.client.PubSub) -> None: try: async with aiotools.aclosing( - redis.subscribe(pubsub, reconnect_poll_interval=0.3), + redis_helper.subscribe(pubsub, reconnect_poll_interval=0.3), ) as agen: async for raw_msg in agen: msg = raw_msg.decode() @@ -123,10 +123,10 @@ async def subscribe(pubsub: aioredis.client.PubSub) -> None: pass r = RedisConnectionInfo( - aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + redis.asyncio.from_url(url='redis://localhost:9379', socket_timeout=0.5), service_name=None, ) - assert isinstance(r.client, aioredis.Redis) + assert isinstance(r.client, redis.asyncio.Redis) await r.client.delete("ch1") pubsub = r.client.pubsub() async with pubsub: @@ -144,7 +144,7 @@ async def subscribe(pubsub: aioredis.client.PubSub) -> None: await asyncio.sleep(0) for i in range(5): - await redis.execute(r, lambda r: r.publish("ch1", str(i))) + await redis_helper.execute(r, lambda r: r.publish("ch1", str(i))) await asyncio.sleep(0.1) do_pause.set() await paused.wait() @@ -155,13 +155,13 @@ async def wakeup(): wakeup_task = asyncio.create_task(wakeup()) for i in range(5): - await redis.execute(r, lambda r: r.publish("ch1", str(5 + i))) + await redis_helper.execute(r, lambda r: r.publish("ch1", str(5 + i))) await asyncio.sleep(0.1) await wakeup_task await unpaused.wait() for i in range(5): - await redis.execute(r, lambda r: r.publish("ch1", str(10 + i))) + await redis_helper.execute(r, lambda r: r.publish("ch1", str(10 + i))) await asyncio.sleep(0.1) await interrupt_task @@ -198,7 +198,7 @@ async def interrupt() -> None: await asyncio.sleep(0.5) unpaused.set() - async def subscribe(pubsub: aioredis.client.PubSub) -> None: + async def subscribe(pubsub: redis.asyncio.client.PubSub) -> None: try: async with aiotools.aclosing( redis.subscribe(pubsub, reconnect_poll_interval=0.3) @@ -210,7 +210,7 @@ async def subscribe(pubsub: aioredis.client.PubSub) -> None: except asyncio.CancelledError: pass - s = aioredis.sentinel.Sentinel( + s = redis.asyncio.sentinel.Sentinel( redis_cluster.sentinel_addrs, password='develove', socket_timeout=0.5, diff --git a/tests/redis/test_stream.py b/tests/redis/test_stream.py index 484c12b1..de05233d 100644 --- a/tests/redis/test_stream.py +++ b/tests/redis/test_stream.py @@ -8,15 +8,14 @@ List, ) -import aioredis -import aioredis.client -import aioredis.exceptions -import aioredis.sentinel +import redis.asyncio +import redis.asyncio.client +import redis.exceptions import aiotools from aiotools.context import aclosing import pytest -from ai.backend.common import redis +from ai.backend.common import redis_helper from ai.backend.common.types import RedisConnectionInfo from .docker import DockerRedisNode @@ -43,7 +42,7 @@ async def consume( key: str, ) -> None: try: - async with aclosing(redis.read_stream(r, key)) as agen: + async with aclosing(redis_helper.read_stream(r, key)) as agen: async for msg_id, msg_data in agen: print(f"XREAD[{consumer_id}]", msg_id, repr(msg_data), file=sys.stderr) received_messages[consumer_id].append(msg_data[b"idx"]) @@ -54,11 +53,11 @@ async def consume( raise r = RedisConnectionInfo( - aioredis.from_url('redis://localhost:9379', socket_timeout=0.5), + redis.asyncio.from_url('redis://localhost:9379', socket_timeout=0.5), service_name=None, ) - assert isinstance(r.client, aioredis.Redis) - await redis.execute(r, lambda r: r.delete("stream1")) + assert isinstance(r.client, redis.asyncio.Redis) + await redis_helper.execute(r, lambda r: r.delete("stream1")) consumer_tasks = [ asyncio.create_task(consume("c1", r, "stream1")), @@ -85,7 +84,7 @@ async def consume( for i in range(5): # The Redis server is dead temporarily... if disruption_method == 'stop': - with pytest.raises(aioredis.exceptions.ConnectionError): + with pytest.raises(redis.exceptions.ConnectionError): await r.client.xadd("stream1", {"idx": 5 + i}) elif disruption_method == 'pause': with pytest.raises(asyncio.TimeoutError): @@ -136,7 +135,7 @@ async def consume( key: str, ) -> None: try: - async with aclosing(redis.read_stream(r, key)) as agen: + async with aclosing(redis_helper.read_stream(r, key)) as agen: async for msg_id, msg_data in agen: print(f"XREAD[{consumer_id}]", msg_id, repr(msg_data), file=sys.stderr) received_messages[consumer_id].append(msg_data[b"idx"]) @@ -147,14 +146,14 @@ async def consume( raise s = RedisConnectionInfo( - aioredis.sentinel.Sentinel( + redis.asyncio.Sentinel( redis_cluster.sentinel_addrs, password='develove', socket_timeout=0.5, ), service_name='mymaster', ) - _execute = aiotools.apartial(redis.execute, s) + _execute = aiotools.apartial(redis_helper.execute, s) await _execute(lambda r: r.delete("stream1")) consumer_tasks = [ @@ -227,7 +226,7 @@ async def consume( key: str, ) -> None: try: - async with aclosing(redis.read_stream_by_group( + async with aclosing(redis_helper.read_stream_by_group( r, key, group_name, consumer_id, autoclaim_idle_timeout=500, )) as agen: @@ -241,12 +240,12 @@ async def consume( return r = RedisConnectionInfo( - aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + redis.asyncio.from_url(url='redis://localhost:9379', socket_timeout=0.5), service_name=None, ) - assert isinstance(r.client, aioredis.Redis) - await redis.execute(r, lambda r: r.delete("stream1")) - await redis.execute(r, lambda r: r.xgroup_create("stream1", "group1", b"$", mkstream=True)) + assert isinstance(r.client, redis.asyncio.Redis) + await redis_helper.execute(r, lambda r: r.delete("stream1")) + await redis_helper.execute(r, lambda r: r.xgroup_create("stream1", "group1", b"$", mkstream=True)) consumer_tasks = [ asyncio.create_task(consume("group1", "c1", r, "stream1")), @@ -273,7 +272,7 @@ async def consume( for i in range(5): # The Redis server is dead temporarily... if disruption_method == 'stop': - with pytest.raises(aioredis.exceptions.ConnectionError): + with pytest.raises(redis.exceptions.ConnectionError): await r.client.xadd("stream1", {"idx": 5 + i}) elif disruption_method == 'pause': with pytest.raises(asyncio.TimeoutError): @@ -325,7 +324,7 @@ async def consume( key: str, ) -> None: try: - async with aclosing(redis.read_stream_by_group( + async with aclosing(redis_helper.read_stream_by_group( r, key, group_name, consumer_id, autoclaim_idle_timeout=500, )) as agen: @@ -339,14 +338,14 @@ async def consume( return s = RedisConnectionInfo( - aioredis.sentinel.Sentinel( + redis.asyncio.Sentinel( redis_cluster.sentinel_addrs, password='develove', socket_timeout=0.5, ), service_name='mymaster', ) - _execute = aiotools.apartial(redis.execute, s) + _execute = aiotools.apartial(redis_helper.execute, s) await _execute(lambda r: r.delete("stream1")) await _execute(lambda r: r.xgroup_create("stream1", "group1", b"$", mkstream=True)) diff --git a/tests/redis/utils.py b/tests/redis/utils.py index 54b4394b..a0b6d6f2 100644 --- a/tests/redis/utils.py +++ b/tests/redis/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations -import aioredis -import aioredis.exceptions +import redis.asyncio +import redis.exceptions import async_timeout import asyncio import functools @@ -41,20 +41,20 @@ async def simple_run_cmd(cmdargs: Sequence[Union[str, bytes]], **kwargs) -> asyn async def wait_redis_ready(host: str, port: int, password: str = None) -> None: - r = aioredis.from_url(f"redis://{host}:{port}", password=password, socket_timeout=0.2) + r = redis.asyncio.from_url(f"redis://{host}:{port}", password=password, socket_timeout=0.2) while True: try: print("CheckReady.PING", port, file=sys.stderr) await r.ping() print("CheckReady.PONG", port, file=sys.stderr) - except aioredis.exceptions.AuthenticationError: + except redis.exceptions.AuthenticationError: raise except ( ConnectionResetError, - aioredis.exceptions.ConnectionError, + redis.exceptions.ConnectionError, ): await asyncio.sleep(0.1) - except aioredis.exceptions.TimeoutError: + except redis.exceptions.TimeoutError: pass else: break diff --git a/tests/test_events.py b/tests/test_events.py index ef9672b3..101a6ce3 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -18,7 +18,7 @@ EtcdRedisConfig, HostPortPair, ) -from ai.backend.common import redis +from ai.backend.common import redis_helper @attr.s(slots=True, frozen=True) @@ -71,7 +71,7 @@ def scb(context: object, source: AgentId, event: DummyEvent) -> None: await asyncio.sleep(0.2) assert records == {'async', 'sync'} - await redis.execute(producer.redis_client, lambda r: r.flushdb()) + await redis_helper.execute(producer.redis_client, lambda r: r.flushdb()) await producer.close() await dispatcher.close() @@ -118,7 +118,7 @@ def scb(context: object, source: AgentId, event: DummyEvent) -> None: assert 'ZeroDivisionError' in exception_log assert 'OverflowError' in exception_log - await redis.execute(producer.redis_client, lambda r: r.flushdb()) + await redis_helper.execute(producer.redis_client, lambda r: r.flushdb()) await producer.close() await dispatcher.close() From 46a25d26aff274debaa1d168c743bcc43d06b690 Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Mon, 4 Apr 2022 04:53:43 +0000 Subject: [PATCH 2/3] docs: Add news fragment --- changes/134.fix.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/134.fix.md diff --git a/changes/134.fix.md b/changes/134.fix.md new file mode 100644 index 00000000..eb67911e --- /dev/null +++ b/changes/134.fix.md @@ -0,0 +1 @@ +Migrate from aioredis to redis-py v4.2.1 in favor of the official support From 308f299478443ff6beb488307caf1d4ecfa1ee54 Mon Sep 17 00:00:00 2001 From: Joongi Kim Date: Mon, 4 Apr 2022 04:58:00 +0000 Subject: [PATCH 3/3] fix: Update common.events --- src/ai/backend/common/events.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ai/backend/common/events.py b/src/ai/backend/common/events.py index b92bc253..678ae68d 100644 --- a/src/ai/backend/common/events.py +++ b/src/ai/backend/common/events.py @@ -28,9 +28,9 @@ from typing_extensions import TypeAlias import uuid -import aioredis -import aioredis.exceptions -import aioredis.sentinel +import redis.asyncio +import redis.asyncio.sentinel +import redis.exceptions from aiotools.context import aclosing from aiotools.server import process_index from aiotools.taskgroup import PersistentTaskGroup @@ -522,7 +522,7 @@ class BgtaskFailedEvent(BgtaskDoneEventArgs, AbstractEvent): class RedisConnectorFunc(Protocol): def __call__( self, - ) -> aioredis.ConnectionPool: + ) -> redis.asyncio.ConnectionPool: ...