Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning for the length of the group name #2122

Merged
merged 12 commits into from
Jan 28, 2025
14 changes: 9 additions & 5 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ def valid_channel_name(self, name, receive=False):
return True
raise TypeError(self.invalid_name_error.format("Channel", name))

def valid_group_name(self, name):
def require_valid_group_name(self, name):
if len(name) >= self.MAX_NAME_LENGTH:
raise TypeError(
f"Group name must be less than {self.MAX_NAME_LENGTH} characters."
)
if self.match_type_and_length(name):
if bool(self.group_name_regex.match(name)):
return True
Expand Down Expand Up @@ -341,16 +345,16 @@ async def group_add(self, group, channel):
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_group_name(group)
self.valid_channel_name(channel), "Channel name not valid"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you rename valid_channel_name to require_valid_channel_name to match?

And also remove the , "Channel name not valid" since it isn't used?

# Add to group dict
self.groups.setdefault(group, {})
self.groups[group][channel] = time.time()

async def group_discard(self, group, channel):
# Both should be text and valid
assert self.valid_channel_name(channel), "Invalid channel name"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here as aboev

assert self.valid_group_name(group), "Invalid group name"
self.require_valid_group_name(group)
# Remove from group set
group_channels = self.groups.get(group, None)
if group_channels:
Expand All @@ -363,7 +367,7 @@ async def group_discard(self, group, channel):
async def group_send(self, group, message):
# Check types
assert isinstance(message, dict), "Message is not a dict"
assert self.valid_group_name(group), "Invalid group name"
self.require_valid_group_name(group)
# Run clean
self._clean_expired()

Expand Down
26 changes: 25 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ async def test_send_receive():

@pytest.mark.parametrize(
"method",
[BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name],
[
BaseChannelLayer().valid_channel_name,
BaseChannelLayer().require_valid_group_name,
],
)
@pytest.mark.parametrize(
"channel_name,expected_valid",
Expand All @@ -84,3 +87,24 @@ def test_channel_and_group_name_validation(method, channel_name, expected_valid)
else:
with pytest.raises(TypeError):
method(channel_name)


@pytest.mark.parametrize(
"name, expected_error_message",
[
(
"a" * 101,
f"Group name must be less than {BaseChannelLayer.MAX_NAME_LENGTH} "
"characters.",
), # Group name too long
],
)
def test_group_name_length_error_message(name, expected_error_message):
"""
Ensure the correct error message is raised when group names
exceed the character limit.
"""
layer = BaseChannelLayer()

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_group_name(name)
Loading