From 49e419dce642b7aef12daa3fa8495f6b9beb1fbe Mon Sep 17 00:00:00 2001 From: Chase Bennett <33002121+qeternity@users.noreply.github.com> Date: Sun, 25 Jul 2021 18:24:16 +0100 Subject: [PATCH] pubsub event loop layer proxy (#262) --- channels_redis/pubsub.py | 64 +++++++++++++++++++++++++++++++++++ tests/test_pubsub.py | 35 +++++++++++++++++++ tests/test_pubsub_sentinel.py | 35 +++++++++++++++++++ tests/test_sentinel.py | 4 +-- 4 files changed, 135 insertions(+), 3 deletions(-) diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index 2acc9bb..62907d5 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -1,5 +1,8 @@ import asyncio +import functools import logging +import sys +import types import uuid import aioredis @@ -8,7 +11,68 @@ logger = logging.getLogger(__name__) +if sys.version_info >= (3, 7): + get_running_loop = asyncio.get_running_loop +else: + get_running_loop = asyncio.get_event_loop + + +def _wrap_close(proxy, loop): + original_impl = loop.close + + def _wrapper(self, *args, **kwargs): + if loop in proxy._layers: + layer = proxy._layers[loop] + del proxy._layers[loop] + loop.run_until_complete(layer.flush()) + + self.close = original_impl + return self.close(*args, **kwargs) + + loop.close = types.MethodType(_wrapper, loop) + + class RedisPubSubChannelLayer: + def __init__(self, *args, **kwargs) -> None: + self._args = args + self._kwargs = kwargs + self._layers = {} + + def __getattr__(self, name): + if name in ( + "new_channel", + "send", + "receive", + "group_add", + "group_discard", + "group_send", + "flush", + ): + return functools.partial(self._proxy, name) + else: + return getattr(self._get_layer(), name) + + def _get_layer(self): + loop = get_running_loop() + + try: + layer = self._layers[loop] + except KeyError: + layer = RedisPubSubLoopLayer(*self._args, **self._kwargs) + self._layers[loop] = layer + _wrap_close(self, loop) + + return layer + + def _proxy(self, name, *args, **kwargs): + async def coro(): + layer = self._get_layer() + return await getattr(layer, name)(*args, **kwargs) + + return coro() + + +class RedisPubSubLoopLayer: """ Channel Layer that uses Redis's pub/sub functionality. """ diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index f79b9d5..f890367 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -5,6 +5,7 @@ import pytest from async_generator import async_generator, yield_ +from asgiref.sync import async_to_sync from channels_redis.pubsub import RedisPubSubChannelLayer TEST_HOSTS = [("localhost", 6379)] @@ -33,6 +34,18 @@ async def test_send_receive(channel_layer): assert message["text"] == "Ahoy-hoy!" +@pytest.mark.asyncio +def test_send_receive_sync(channel_layer, event_loop): + _await = event_loop.run_until_complete + channel = _await(channel_layer.new_channel()) + async_to_sync(channel_layer.send, force_new_loop=True)( + channel, {"type": "test.message", "text": "Ahoy-hoy!"} + ) + message = _await(channel_layer.receive(channel)) + assert message["type"] == "test.message" + assert message["text"] == "Ahoy-hoy!" + + @pytest.mark.asyncio async def test_multi_send_receive(channel_layer): """ @@ -47,6 +60,19 @@ async def test_multi_send_receive(channel_layer): assert (await channel_layer.receive(channel))["type"] == "message.3" +@pytest.mark.asyncio +def test_multi_send_receive_sync(channel_layer, event_loop): + _await = event_loop.run_until_complete + channel = _await(channel_layer.new_channel()) + send = async_to_sync(channel_layer.send) + send(channel, {"type": "message.1"}) + send(channel, {"type": "message.2"}) + send(channel, {"type": "message.3"}) + assert _await(channel_layer.receive(channel))["type"] == "message.1" + assert _await(channel_layer.receive(channel))["type"] == "message.2" + assert _await(channel_layer.receive(channel))["type"] == "message.3" + + @pytest.mark.asyncio async def test_groups_basic(channel_layer): """ @@ -101,3 +127,12 @@ async def test_random_reset__channel_name(channel_layer): channel_name_2 = await channel_layer.new_channel() assert channel_name_1 != channel_name_2 + + +def test_multi_event_loop_garbage_collection(channel_layer): + """ + Test loop closure layer flushing and garbage collection + """ + assert len(channel_layer._layers.values()) == 0 + async_to_sync(test_send_receive)(channel_layer) + assert len(channel_layer._layers.values()) == 0 diff --git a/tests/test_pubsub_sentinel.py b/tests/test_pubsub_sentinel.py index 233cfcf..053764c 100644 --- a/tests/test_pubsub_sentinel.py +++ b/tests/test_pubsub_sentinel.py @@ -5,6 +5,7 @@ import pytest from async_generator import async_generator, yield_ +from asgiref.sync import async_to_sync from channels_redis.pubsub import RedisPubSubChannelLayer SENTINEL_MASTER = "sentinel" @@ -34,6 +35,18 @@ async def test_send_receive(channel_layer): assert message["text"] == "Ahoy-hoy!" +@pytest.mark.asyncio +def test_send_receive_sync(channel_layer, event_loop): + _await = event_loop.run_until_complete + channel = _await(channel_layer.new_channel()) + async_to_sync(channel_layer.send, force_new_loop=True)( + channel, {"type": "test.message", "text": "Ahoy-hoy!"} + ) + message = _await(channel_layer.receive(channel)) + assert message["type"] == "test.message" + assert message["text"] == "Ahoy-hoy!" + + @pytest.mark.asyncio async def test_multi_send_receive(channel_layer): """ @@ -48,6 +61,19 @@ async def test_multi_send_receive(channel_layer): assert (await channel_layer.receive(channel))["type"] == "message.3" +@pytest.mark.asyncio +def test_multi_send_receive_sync(channel_layer, event_loop): + _await = event_loop.run_until_complete + channel = _await(channel_layer.new_channel()) + send = async_to_sync(channel_layer.send) + send(channel, {"type": "message.1"}) + send(channel, {"type": "message.2"}) + send(channel, {"type": "message.3"}) + assert _await(channel_layer.receive(channel))["type"] == "message.1" + assert _await(channel_layer.receive(channel))["type"] == "message.2" + assert _await(channel_layer.receive(channel))["type"] == "message.3" + + @pytest.mark.asyncio async def test_groups_basic(channel_layer): """ @@ -102,3 +128,12 @@ async def test_random_reset__channel_name(channel_layer): channel_name_2 = await channel_layer.new_channel() assert channel_name_1 != channel_name_2 + + +def test_multi_event_loop_garbage_collection(channel_layer): + """ + Test loop closure layer flushing and garbage collection + """ + assert len(channel_layer._layers.values()) == 0 + async_to_sync(test_send_receive)(channel_layer) + assert len(channel_layer._layers.values()) == 0 diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index eb22763..3c69c5e 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -62,9 +62,7 @@ async def channel_layer(): Channel layer fixture that flushes automatically. """ channel_layer = RedisChannelLayer( - hosts=TEST_HOSTS, - capacity=3, - channel_capacity={"tiny": 1}, + hosts=TEST_HOSTS, capacity=3, channel_capacity={"tiny": 1} ) await yield_(channel_layer) await channel_layer.flush()