Skip to content

Commit

Permalink
Fix #273 -- Use consistent hashing for PubSub (#274)
Browse files Browse the repository at this point in the history
* Fix #273 -- Use consistent hashing for PubSub

* Update pubsub.py

* Add missing import

* Refactor hash function into utils module and add test

* Run black
  • Loading branch information
raphaelm authored Sep 6, 2021
1 parent 49e419d commit f5eef16
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 15 deletions.
13 changes: 3 additions & 10 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import base64
import binascii
import collections
import functools
import hashlib
Expand All @@ -18,6 +17,8 @@
from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer

from .utils import _consistent_hash

logger = logging.getLogger(__name__)

AIOREDIS_VERSION = tuple(map(int, aioredis.__version__.split(".")))
Expand Down Expand Up @@ -858,15 +859,7 @@ def deserialize(self, message):
### Internal functions ###

def consistent_hash(self, value):
    """
    Maps the value to one of this layer's ring nodes.

    The actual hashing lives in the shared ``_consistent_hash`` helper
    (imported from ``.utils``); this method only supplies the layer's own
    ``ring_size`` so callers keep the same method-based interface.
    """
    # Delegate instead of keeping a second inline copy of the CRC logic:
    # a single implementation keeps every caller of the helper in agreement.
    return _consistent_hash(value, self.ring_size)

def make_fernet(self, key):
"""
Expand Down
8 changes: 3 additions & 5 deletions channels_redis/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import aioredis
import msgpack

from .utils import _consistent_hash

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -106,11 +108,7 @@ def _get_shard(self, channel_or_group_name):
"""
Return the shard that is used exclusively for this channel or group.
"""
if len(self._shards) == 1:
# Avoid the overhead of hashing and modulo when it is unnecessary.
return self._shards[0]
shard_index = abs(hash(channel_or_group_name)) % len(self._shards)
return self._shards[shard_index]
return self._shards[_consistent_hash(channel_or_group_name, len(self._shards))]

def _get_group_channel_name(self, group):
"""
Expand Down
17 changes: 17 additions & 0 deletions channels_redis/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import binascii


def _consistent_hash(value, ring_size):
"""
Maps the value to a node value between 0 and 4095
using CRC, then down to one of the ring nodes.
"""
if ring_size == 1:
# Avoid the overhead of hashing and modulo when it is unnecessary.
return 0

if isinstance(value, str):
value = value.encode("utf8")
bigval = binascii.crc32(value) & 0xFFF
ring_divisor = 4096 / float(ring_size)
return int(bigval / ring_divisor)
20 changes: 20 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from channels_redis.utils import _consistent_hash


@pytest.mark.parametrize(
    ("value", "ring_size", "expected"),
    [
        # Single-node ring: every key lands on index 0.
        ("key_one", 1, 0),
        ("key_two", 1, 0),
        # Two-node ring.
        ("key_one", 2, 1),
        ("key_two", 2, 0),
        # Ten-node ring, text keys.
        ("key_one", 10, 6),
        ("key_two", 10, 4),
        # Ten-node ring, the same keys given as bytes.
        (b"key_one", 10, 6),
        (b"key_two", 10, 4),
    ],
)
def test_consistent_hash_result(value, ring_size, expected):
    """Check that ``_consistent_hash`` yields the pinned node index for each case."""
    node_index = _consistent_hash(value, ring_size)
    assert node_index == expected

0 comments on commit f5eef16

Please sign in to comment.