Skip to content

Commit

Permalink
pubsub event loop layer proxy (django#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
qeternity authored Jul 25, 2021
1 parent 9793d09 commit 49e419d
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 3 deletions.
64 changes: 64 additions & 0 deletions channels_redis/pubsub.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import functools
import logging
import sys
import types
import uuid

import aioredis
Expand All @@ -8,7 +11,68 @@
logger = logging.getLogger(__name__)


if sys.version_info >= (3, 7):
get_running_loop = asyncio.get_running_loop
else:
get_running_loop = asyncio.get_event_loop


def _wrap_close(proxy, loop):
original_impl = loop.close

def _wrapper(self, *args, **kwargs):
if loop in proxy._layers:
layer = proxy._layers[loop]
del proxy._layers[loop]
loop.run_until_complete(layer.flush())

self.close = original_impl
return self.close(*args, **kwargs)

loop.close = types.MethodType(_wrapper, loop)


class RedisPubSubChannelLayer:
def __init__(self, *args, **kwargs) -> None:
self._args = args
self._kwargs = kwargs
self._layers = {}

def __getattr__(self, name):
if name in (
"new_channel",
"send",
"receive",
"group_add",
"group_discard",
"group_send",
"flush",
):
return functools.partial(self._proxy, name)
else:
return getattr(self._get_layer(), name)

def _get_layer(self):
loop = get_running_loop()

try:
layer = self._layers[loop]
except KeyError:
layer = RedisPubSubLoopLayer(*self._args, **self._kwargs)
self._layers[loop] = layer
_wrap_close(self, loop)

return layer

def _proxy(self, name, *args, **kwargs):
async def coro():
layer = self._get_layer()
return await getattr(layer, name)(*args, **kwargs)

return coro()


class RedisPubSubLoopLayer:
"""
Channel Layer that uses Redis's pub/sub functionality.
"""
Expand Down
35 changes: 35 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from async_generator import async_generator, yield_

from asgiref.sync import async_to_sync
from channels_redis.pubsub import RedisPubSubChannelLayer

TEST_HOSTS = [("localhost", 6379)]
Expand Down Expand Up @@ -33,6 +34,18 @@ async def test_send_receive(channel_layer):
assert message["text"] == "Ahoy-hoy!"


@pytest.mark.asyncio
def test_send_receive_sync(channel_layer, event_loop):
_await = event_loop.run_until_complete
channel = _await(channel_layer.new_channel())
async_to_sync(channel_layer.send, force_new_loop=True)(
channel, {"type": "test.message", "text": "Ahoy-hoy!"}
)
message = _await(channel_layer.receive(channel))
assert message["type"] == "test.message"
assert message["text"] == "Ahoy-hoy!"


@pytest.mark.asyncio
async def test_multi_send_receive(channel_layer):
"""
Expand All @@ -47,6 +60,19 @@ async def test_multi_send_receive(channel_layer):
assert (await channel_layer.receive(channel))["type"] == "message.3"


@pytest.mark.asyncio
def test_multi_send_receive_sync(channel_layer, event_loop):
_await = event_loop.run_until_complete
channel = _await(channel_layer.new_channel())
send = async_to_sync(channel_layer.send)
send(channel, {"type": "message.1"})
send(channel, {"type": "message.2"})
send(channel, {"type": "message.3"})
assert _await(channel_layer.receive(channel))["type"] == "message.1"
assert _await(channel_layer.receive(channel))["type"] == "message.2"
assert _await(channel_layer.receive(channel))["type"] == "message.3"


@pytest.mark.asyncio
async def test_groups_basic(channel_layer):
"""
Expand Down Expand Up @@ -101,3 +127,12 @@ async def test_random_reset__channel_name(channel_layer):
channel_name_2 = await channel_layer.new_channel()

assert channel_name_1 != channel_name_2


def test_multi_event_loop_garbage_collection(channel_layer):
"""
Test loop closure layer flushing and garbage collection
"""
assert len(channel_layer._layers.values()) == 0
async_to_sync(test_send_receive)(channel_layer)
assert len(channel_layer._layers.values()) == 0
35 changes: 35 additions & 0 deletions tests/test_pubsub_sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from async_generator import async_generator, yield_

from asgiref.sync import async_to_sync
from channels_redis.pubsub import RedisPubSubChannelLayer

SENTINEL_MASTER = "sentinel"
Expand Down Expand Up @@ -34,6 +35,18 @@ async def test_send_receive(channel_layer):
assert message["text"] == "Ahoy-hoy!"


@pytest.mark.asyncio
def test_send_receive_sync(channel_layer, event_loop):
_await = event_loop.run_until_complete
channel = _await(channel_layer.new_channel())
async_to_sync(channel_layer.send, force_new_loop=True)(
channel, {"type": "test.message", "text": "Ahoy-hoy!"}
)
message = _await(channel_layer.receive(channel))
assert message["type"] == "test.message"
assert message["text"] == "Ahoy-hoy!"


@pytest.mark.asyncio
async def test_multi_send_receive(channel_layer):
"""
Expand All @@ -48,6 +61,19 @@ async def test_multi_send_receive(channel_layer):
assert (await channel_layer.receive(channel))["type"] == "message.3"


@pytest.mark.asyncio
def test_multi_send_receive_sync(channel_layer, event_loop):
_await = event_loop.run_until_complete
channel = _await(channel_layer.new_channel())
send = async_to_sync(channel_layer.send)
send(channel, {"type": "message.1"})
send(channel, {"type": "message.2"})
send(channel, {"type": "message.3"})
assert _await(channel_layer.receive(channel))["type"] == "message.1"
assert _await(channel_layer.receive(channel))["type"] == "message.2"
assert _await(channel_layer.receive(channel))["type"] == "message.3"


@pytest.mark.asyncio
async def test_groups_basic(channel_layer):
"""
Expand Down Expand Up @@ -102,3 +128,12 @@ async def test_random_reset__channel_name(channel_layer):
channel_name_2 = await channel_layer.new_channel()

assert channel_name_1 != channel_name_2


def test_multi_event_loop_garbage_collection(channel_layer):
"""
Test loop closure layer flushing and garbage collection
"""
assert len(channel_layer._layers.values()) == 0
async_to_sync(test_send_receive)(channel_layer)
assert len(channel_layer._layers.values()) == 0
4 changes: 1 addition & 3 deletions tests/test_sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ async def channel_layer():
Channel layer fixture that flushes automatically.
"""
channel_layer = RedisChannelLayer(
hosts=TEST_HOSTS,
capacity=3,
channel_capacity={"tiny": 1},
hosts=TEST_HOSTS, capacity=3, channel_capacity={"tiny": 1}
)
await yield_(channel_layer)
await channel_layer.flush()
Expand Down

0 comments on commit 49e419d

Please sign in to comment.