Skip to content

Commit

Permalink
Refactor WebSocketWriter to remove high level protocol functions
Browse files Browse the repository at this point in the history
ping and pong are details that are implemented in ``ClientWebSocketResponse`` and
``WebSocketResponse``. Keep the writer clean by limiting it to sending frames
and closing

closes #2837
  • Loading branch information
bdraco committed Oct 28, 2024
1 parent 68b3378 commit d3c3fb8
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 25 deletions.
17 changes: 9 additions & 8 deletions aiohttp/_websocket/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@


class WebSocketWriter:
"""WebSocket writer.
The writer is responsible for sending messages to the client. It is
created by the protocol when a connection is established. The writer
should avoid implementing any application logic and should only be
concerned with the low-level details of the WebSocket protocol.
"""

def __init__(
self,
protocol: BaseProtocol,
Expand All @@ -45,6 +53,7 @@ def __init__(
compress: int = 0,
notakeover: bool = False,
) -> None:
"""Initialize a WebSocket writer."""
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
Expand Down Expand Up @@ -155,14 +164,6 @@ def _make_compress_obj(self, compress: int) -> ZLibCompressor:
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)

async def pong(self, message: bytes = b"") -> None:
"""Send pong message."""
await self.send_frame(message, WSMsgType.PONG)

async def ping(self, message: bytes = b"") -> None:
"""Send ping message."""
await self.send_frame(message, WSMsgType.PING)

async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
Expand Down
9 changes: 5 additions & 4 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

coro = self._writer.send_frame(b"", WSMsgType.PING)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
ping_task = loop.create_task(self._writer.ping())
ping_task = loop.create_task(coro)

if not ping_task.done():
self._ping_task = ping_task
Expand Down Expand Up @@ -225,10 +226,10 @@ def exception(self) -> Optional[BaseException]:
return self._exception

async def ping(self, message: bytes = b"") -> None:
await self._writer.ping(message)
await self._writer.send_frame(message, WSMsgType.PING)

async def pong(self, message: bytes = b"") -> None:
await self._writer.pong(message)
await self._writer.send_frame(message, WSMsgType.PONG)

async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
Expand Down
9 changes: 5 additions & 4 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,14 @@ def _send_heartbeat(self) -> None:
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

coro = self._writer.send_frame(b"", WSMsgType.PING)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(self._writer.ping(), loop=loop, eager_start=True)
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
ping_task = loop.create_task(self._writer.ping())
ping_task = loop.create_task(coro)

if not ping_task.done():
self._ping_task = ping_task
Expand Down Expand Up @@ -397,13 +398,13 @@ def exception(self) -> Optional[BaseException]:
async def ping(self, message: bytes = b"") -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
await self._writer.ping(message)
await self._writer.send_frame(message, WSMsgType.PING)

async def pong(self, message: bytes = b"") -> None:
# unsolicited pong
if self._writer is None:
raise RuntimeError("Call .prepare() first")
await self._writer.pong(message)
await self._writer.send_frame(message, WSMsgType.PONG)

async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
Expand Down
15 changes: 11 additions & 4 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,9 +705,11 @@ async def handler(request: web.Request) -> NoReturn:
assert resp._conn is not None
with mock.patch.object(
resp._conn.transport, "write", side_effect=ClientConnectionResetError
), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping:
), mock.patch.object(
resp._writer, "send_frame", wraps=resp._writer.send_frame
) as send_frame:
await resp.receive()
ping_count = ping.call_count
ping_count = send_frame.call_args_list.count(mock.call(b"", WSMsgType.PING))
# Connection should be closed roughly after 1.5x heartbeat.
await asyncio.sleep(0.2)
assert ping_count == 1
Expand Down Expand Up @@ -871,8 +873,13 @@ async def handler(request: web.Request) -> NoReturn:

cancelled = False
ping_stated = False
original_send_frame = resp._writer.send_frame

async def delayed_ping() -> None:
async def delayed_send_frame(
message: bytes, opcode: int, compress: Optional[int] = None
) -> None:

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note test

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
if opcode != WSMsgType.PING:
return await original_send_frame(message, opcode, compress)
nonlocal cancelled, ping_stated
ping_stated = True
try:
Expand All @@ -881,7 +888,7 @@ async def delayed_ping() -> None:
cancelled = True
raise

with mock.patch.object(resp._writer, "ping", delayed_ping):
with mock.patch.object(resp._writer, "send_frame", delayed_send_frame):
await asyncio.sleep(0.1)

await resp.close()
Expand Down
8 changes: 5 additions & 3 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,12 +785,14 @@ async def handler(request: web.Request) -> NoReturn:
with mock.patch.object(
ws_server._req.transport, "write", side_effect=ConnectionResetError
), mock.patch.object(
ws_server._writer, "ping", wraps=ws_server._writer.ping
) as ping:
ws_server._writer, "send_frame", wraps=ws_server._writer.send_frame
) as send_frame:
try:
await ws_server.receive()
finally:
ping_count = ping.call_count
ping_count = send_frame.call_args_list.count(
mock.call(b"", WSMsgType.PING)
)
assert False

app = web.Application()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def writer(protocol: BaseProtocol, transport: asyncio.Transport) -> WebSocketWri


async def test_pong(writer: WebSocketWriter) -> None:
await writer.pong()
await writer.send_frame(b"", WSMsgType.PONG)
writer.transport.write.assert_called_with(b"\x8a\x00") # type: ignore[attr-defined]


async def test_ping(writer: WebSocketWriter) -> None:
await writer.ping()
await writer.send_frame(b"", WSMsgType.PING)
writer.transport.write.assert_called_with(b"\x89\x00") # type: ignore[attr-defined]


Expand Down

0 comments on commit d3c3fb8

Please sign in to comment.