Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

Commit

Permalink
Merge pull request #624 from NAFTeam/dev
Browse files Browse the repository at this point in the history
NAFF 1.10.0
  • Loading branch information
LordOfPolls authored Sep 5, 2022
2 parents 0081ecb + 5bcc33b commit b7f675a
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 42 deletions.
2 changes: 1 addition & 1 deletion naff/api/events/processors/guild_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def _on_raw_guild_create(self, event: "RawGatewayEvent") -> None:

self._guild_event.set()

if self.fetch_members: # noqa
if self.fetch_members and not guild.chunked.is_set(): # noqa
# delays events until chunking has completed
await guild.chunk()

Expand Down
16 changes: 9 additions & 7 deletions naff/api/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TypeVar, TYPE_CHECKING

from naff.api import events
from naff.client.const import logger, MISSING
from naff.client.const import logger, MISSING, __api_version__
from naff.client.utils.input_utils import OverriddenJson
from naff.client.utils.serializer import dict_filter_none
from naff.models.discord.enums import Status
Expand Down Expand Up @@ -48,7 +48,6 @@ class GatewayClient(WebsocketClient):
Multiple `WebsocketClient` instances can be used to implement same-process sharding.
Attributes:
buffer: A buffer to hold incoming data until its complete
sequence: The sequence of this connection
session_id: The session ID of this connection
Expand Down Expand Up @@ -83,7 +82,7 @@ def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None:
self._ready = asyncio.Event()
self._close_gateway = asyncio.Event()

# Santity check, it is extremely important that an instance isn't reused.
# Sanity check, it is extremely important that an instance isn't reused.
self._entered = False

async def __aenter__(self: SELF) -> SELF:
Expand Down Expand Up @@ -177,6 +176,7 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:
match op:

case OPCODE.HEARTBEAT:
logger.debug("Received heartbeat request from gateway")
return await self.send_heartbeat()

case OPCODE.HEARTBEAT_ACK:
Expand All @@ -192,12 +192,12 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:
return self._acknowledged.set()

case OPCODE.RECONNECT:
logger.info("Gateway requested reconnect. Reconnecting...")
logger.debug("Gateway requested reconnect. Reconnecting...")
return await self.reconnect(resume=True, url=self.ws_resume_url)

case OPCODE.INVALIDATE_SESSION:
logger.warning("Gateway has invalidated session! Reconnecting...")
return await self.reconnect(resume=data, url=self.ws_resume_url if data else None)
return await self.reconnect()

case _:
return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}")
Expand All @@ -209,7 +209,9 @@ async def dispatch_event(self, data, seq, event) -> None:
self._trace = data.get("_trace", [])
self.sequence = seq
self.session_id = data["session_id"]
self.ws_resume_url = data["resume_gateway_url"]
self.ws_resume_url = (
f"{data['resume_gateway_url']}?encoding=json&v={__api_version__}&compress=zlib-stream"
)
logger.info(f"Shard {self.shard[0]} has connected to gateway!")
logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
# todo: future polls, improve guild caching here. run the debugger. you'll see why
Expand Down Expand Up @@ -287,7 +289,7 @@ async def _resume_connection(self) -> None:
logger.debug(f"{self.shard[0]} is attempting to resume a connection")

async def send_heartbeat(self) -> None:
await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, True)
await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, bypass=True)
logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat")

async def change_presence(self, activity=None, status: Status = Status.ONLINE, since=None) -> None:
Expand Down
13 changes: 7 additions & 6 deletions naff/api/gateway/websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import collections
import random
import time
import zlib
from abc import abstractmethod
Expand Down Expand Up @@ -275,12 +276,12 @@ async def _start_bee_gees(self) -> None:
if self.heartbeat_interval is None:
raise RuntimeError

# try:
# await asyncio.wait_for(self._kill_bee_gees.wait(), timeout=self.heartbeat_interval * random.uniform(0, 0.5))
# except asyncio.TimeoutError:
# pass
# else:
# return
try:
await asyncio.wait_for(self._kill_bee_gees.wait(), timeout=self.heartbeat_interval * random.uniform(0, 0.5))
except asyncio.TimeoutError:
pass
else:
return

logger.debug(f"Sending heartbeat every {self.heartbeat_interval} seconds")
while not self._kill_bee_gees.is_set():
Expand Down
5 changes: 5 additions & 0 deletions naff/client/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ def __init__(
self.text = data
super().__init__(f"{self.status}|{self.response.reason}: {f'({self.code}) ' if self.code else ''}{self.text}")

def __str__(self) -> str:
errors = self.search_for_message(self.errors)
out = f"HTTPException: {self.status}|{self.response.reason}: " + "\n".join(errors)
return out

@staticmethod
def search_for_message(errors: dict, lookup: Optional[dict] = None) -> list[str]:
"""
Expand Down
3 changes: 2 additions & 1 deletion naff/client/smart_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ def delete_member(self, guild_id: "Snowflake_Type", user_id: "Snowflake_Type") -
guild_id = to_snowflake(guild_id)

if member := self.member_cache.pop((guild_id, user_id), None):
member.guild._member_ids.discard(user_id)
if member.guild:
member.guild._member_ids.discard(user_id)

self.delete_user_guild(user_id, guild_id)

Expand Down
8 changes: 8 additions & 0 deletions naff/models/discord/auto_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ class KeywordPresetTrigger(BaseTrigger):
)


@define()
class MentionSpamTrigger(BaseTrigger):
"""A trigger that checks if content contains more mentions than allowed"""

mention_total_limit: int = field(default=3, repr=True, metadata=docs("The maximum number of mentions allowed"))


@define()
class BlockMessage(BaseAction):
"""blocks the content of a message according to the rule"""
Expand Down Expand Up @@ -320,4 +327,5 @@ def message(self) -> "Optional[Message]":
AutoModTriggerType.KEYWORD: KeywordTrigger,
AutoModTriggerType.HARMFUL_LINK: HarmfulLinkFilter,
AutoModTriggerType.KEYWORD_PRESET: KeywordPresetTrigger,
AutoModTriggerType.MENTION_SPAM: MentionSpamTrigger,
}
7 changes: 4 additions & 3 deletions naff/models/discord/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,20 +865,21 @@ class AuditLogEventType(CursedIntEnum):
GUILD_HOME_FEATURE_ITEM_UPDATE = 172


class AutoModTriggerType(IntEnum):
class AutoModTriggerType(CursedIntEnum):
KEYWORD = 1
HARMFUL_LINK = 2
SPAM = 3
KEYWORD_PRESET = 4
MENTION_SPAM = 5


class AutoModAction(IntEnum):
class AutoModAction(CursedIntEnum):
BLOCK_MESSAGE = 1
ALERT_MESSAGE = 2
TIMEOUT_USER = 3


class AutoModEvent(IntEnum):
class AutoModEvent(CursedIntEnum):
MESSAGE_SEND = 1


Expand Down
51 changes: 28 additions & 23 deletions naff/models/discord/guild.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import time
from asyncio import QueueEmpty
from collections import namedtuple
from functools import cmp_to_key
from typing import List, Optional, Union, Set, Dict, Any, TYPE_CHECKING
Expand Down Expand Up @@ -119,6 +120,26 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]
return super()._process_dict(data, client)


class MemberIterator(AsyncIterator):
def __init__(self, guild: "Guild", limit: int = 0) -> None:
super().__init__(limit)
self.guild = guild
self._more = True

async def fetch(self) -> list:
if self._more:
expected = self.get_limit

rcv = await self.guild._client.http.list_members(
self.guild.id, limit=expected, after=self.last["id"] if self.last else MISSING
)
if not rcv:
raise QueueEmpty
self._more = len(rcv) == expected
return rcv
raise QueueEmpty


@define()
class Guild(BaseGuild):
"""Guilds in Discord represent an isolated collection of users and channels, and are often referred to as "servers" in the UI."""
Expand Down Expand Up @@ -501,31 +522,15 @@ async def edit_nickname(self, new_nickname: Absent[str] = MISSING, reason: Absen
async def http_chunk(self) -> None:
"""Populates all members of this guild using the REST API."""
start_time = time.perf_counter()
members = []

# request all guild members
after = MISSING
while True:
if members:
after = members[-1]["user"]["id"]
rcv: list = await self._client.http.list_members(self.id, limit=1000, after=after)
members.extend(rcv)
if len(rcv) < 1000:
# we're done
break

# process all members
s = time.monotonic()
for member in members:

iterator = MemberIterator(self)
async for member in iterator:
self._client.cache.place_member_data(self.id, member)
if (time.monotonic() - s) > 0.05:
# look, i get this *could* be a thread, but because it needs to modify data in the main thread,
# it is still blocking. So by periodically yielding to the event loop, we can avoid blocking, and still
# process this data properly
await asyncio.sleep(0)
s = time.monotonic()

self.chunked.set()
logger.info(f"Cached {len(members)} members for {self.id} in {time.perf_counter() - start_time:.2f} seconds")
logger.info(
f"Cached {iterator.total_retrieved} members for {self.id} in {time.perf_counter() - start_time:.2f} seconds"
)

async def gateway_chunk(self, wait=True, presences=True) -> None:
"""
Expand Down
5 changes: 5 additions & 0 deletions naff/models/misc/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def get_limit(self) -> int:
"""Get how the maximum number of items that should be retrieved."""
return min(self._limit - len(self._retrieved_objects), 100) if self._limit else 100

@property
def total_retrieved(self) -> int:
"""Get the total number of objects this iterator has retrieved."""
return len(self._retrieved_objects)

async def add_object(self, obj) -> None:
"""Add an object to iterator's queue."""
return await self._queue.put(obj)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "naff"
version = "1.9.0"
version = "1.10.0"
description = "Not another freaking fork"
authors = ["LordOfPolls <[email protected]>"]

Expand Down

0 comments on commit b7f675a

Please sign in to comment.