Skip to content

Commit

Permalink
Fix timer handle churn in websocket heartbeat (#8608)
Browse files Browse the repository at this point in the history
Co-authored-by: Sam Bull <[email protected]>
  • Loading branch information
bdraco and Dreamsorcerer authored Aug 7, 2024
1 parent b2691f2 commit c4acabc
Show file tree
Hide file tree
Showing 7 changed files with 318 additions and 88 deletions.
3 changes: 3 additions & 0 deletions CHANGES/8608.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Improved websocket performance when messages are sent or received frequently -- by :user:`bdraco`.

The WebSocket heartbeat scheduling algorithm was improved to reduce the ``asyncio`` scheduling overhead by decreasing the number of ``asyncio.TimerHandle`` creations and cancellations.
115 changes: 72 additions & 43 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .client_exceptions import ClientError, ServerTimeoutError
from .client_reqrep import ClientResponse
from .helpers import call_later, set_result
from .helpers import calculate_timeout_when, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
Expand Down Expand Up @@ -72,6 +72,7 @@ def __init__(
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
self._heartbeat_when: float = 0.0
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
Expand All @@ -85,52 +86,64 @@ def __init__(
self._reset_heartbeat()

def _cancel_heartbeat(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None

self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None

def _reset_heartbeat(self) -> None:
self._cancel_heartbeat()
def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None

if self._heartbeat is not None:
self._heartbeat_cb = call_later(
self._send_heartbeat,
self._heartbeat,
self._loop,
timeout_ceil_threshold=(
self._conn._connector._timeout_ceil_threshold
if self._conn is not None
else 5
),
)
def _reset_heartbeat(self) -> None:
    """Move the heartbeat deadline forward, arming a timer if none is pending.

    Does nothing when no heartbeat interval is configured. Otherwise any
    pending pong-response timer is cancelled, a new deadline is computed
    from the loop clock and stored in ``_heartbeat_when``, and a timer is
    scheduled only when ``_heartbeat_cb`` is not already set.
    """
    if self._heartbeat is None:
        return
    self._cancel_pong_response_cb()
    loop = self._loop
    assert loop is not None
    conn = self._conn
    # Use a threshold of 5 when there is no connection to read it from.
    timeout_ceil_threshold = (
        conn._connector._timeout_ceil_threshold if conn is not None else 5
    )
    now = loop.time()
    when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
    self._heartbeat_when = when
    if self._heartbeat_cb is None:
        # We do not cancel the previous heartbeat_cb here because
        # it generates a significant amount of TimerHandle churn
        # which causes asyncio to rebuild the heap frequently.
        # Instead _send_heartbeat() will reschedule the next
        # heartbeat if it fires too early.
        self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

def _send_heartbeat(self) -> None:
if self._heartbeat is not None and not self._closed:
# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
self._loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = call_later(
self._pong_not_received,
self._pong_heartbeat,
self._loop,
timeout_ceil_threshold=(
self._conn._connector._timeout_ceil_threshold
if self._conn is not None
else 5
),
self._heartbeat_cb = None
loop = self._loop
now = loop.time()
if now < self._heartbeat_when:
# Heartbeat fired too early, reschedule
self._heartbeat_cb = loop.call_at(
self._heartbeat_when, self._send_heartbeat
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

def _pong_not_received(self) -> None:
if not self._closed:
self._closed = True
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = ServerTimeoutError()
self._response.close()
Expand All @@ -139,6 +152,22 @@ def _pong_not_received(self) -> None:
WSMessage(WSMsgType.ERROR, self._exception, None)
)

def _set_closed(self) -> None:
"""Set the connection to closed.
Cancel any heartbeat timers and set the closed flag.
"""
self._closed = True
self._cancel_heartbeat()

def _set_closing(self) -> None:
"""Set the connection to closing.
Cancel any heartbeat timers and set the closing flag.
"""
self._closing = True
self._cancel_heartbeat()

@property
def closed(self) -> bool:
return self._closed
Expand Down Expand Up @@ -203,13 +232,12 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
if self._waiting and not self._closing:
assert self._loop is not None
self._close_wait = self._loop.create_future()
self._closing = True
self._set_closing()
self._reader.feed_data(WS_CLOSING_MESSAGE)
await self._close_wait

if not self._closed:
self._cancel_heartbeat()
self._closed = True
self._set_closed()
try:
await self._writer.close(code, message)
except asyncio.CancelledError:
Expand Down Expand Up @@ -278,7 +306,8 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except ClientError:
self._closed = True
# Likely ServerDisconnectedError when connection is lost
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
return WS_CLOSED_MESSAGE
except WebSocketError as exc:
Expand All @@ -287,19 +316,19 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._set_closing()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)

if msg.type is WSMsgType.CLOSE:
self._closing = True
self._set_closing()
self._close_code = msg.data
# Could be closed elsewhere while awaiting reader
if not self._closed and self._autoclose: # type: ignore[redundant-expr]
await self.close()
elif msg.type is WSMsgType.CLOSING:
self._closing = True
self._set_closing()
elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
Expand Down
23 changes: 17 additions & 6 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,23 @@ def call_later(
loop: asyncio.AbstractEventLoop,
timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
if timeout is not None and timeout > 0:
when = loop.time() + timeout
if timeout > timeout_ceil_threshold:
when = ceil(when)
return loop.call_at(when, cb)
return None
if timeout is None or timeout <= 0:
return None
now = loop.time()
when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
return loop.call_at(when, cb)


def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Calculate when to execute a timeout.

    The deadline is ``loop_time + timeout``. When ``timeout`` is strictly
    greater than ``timeout_ceiling_threshold`` the deadline is rounded up
    with ``ceil``; otherwise it is returned unchanged.
    """
    when = loop_time + timeout
    if timeout > timeout_ceiling_threshold:
        return ceil(when)
    return when


class TimeoutHandle:
Expand Down
101 changes: 62 additions & 39 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import call_later, set_exception, set_result
from .helpers import calculate_timeout_when, set_exception, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
Expand Down Expand Up @@ -74,6 +74,7 @@ class WebSocketResponse(StreamResponse):
"_autoclose",
"_autoping",
"_heartbeat",
"_heartbeat_when",
"_heartbeat_cb",
"_pong_heartbeat",
"_pong_response_cb",
Expand Down Expand Up @@ -112,6 +113,7 @@ def __init__(
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_when = 0.0
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
Expand All @@ -120,57 +122,76 @@ def __init__(
self._max_msg_size = max_msg_size

def _cancel_heartbeat(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None

self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None

def _reset_heartbeat(self) -> None:
self._cancel_heartbeat()
def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None

if self._heartbeat is not None:
assert self._loop is not None
self._heartbeat_cb = call_later(
self._send_heartbeat,
self._heartbeat,
self._loop,
timeout_ceil_threshold=(
self._req._protocol._timeout_ceil_threshold
if self._req is not None
else 5
),
)
def _reset_heartbeat(self) -> None:
    """Move the heartbeat deadline forward, arming a timer if none is pending.

    Does nothing when no heartbeat interval is configured. Otherwise any
    pending pong-response timer is cancelled, a new deadline is computed
    from the loop clock and stored in ``_heartbeat_when``, and a timer is
    scheduled only when ``_heartbeat_cb`` is not already set.
    """
    if self._heartbeat is None:
        return
    self._cancel_pong_response_cb()
    req = self._req
    # Use a threshold of 5 when there is no request to read it from.
    timeout_ceil_threshold = (
        req._protocol._timeout_ceil_threshold if req is not None else 5
    )
    loop = self._loop
    assert loop is not None
    now = loop.time()
    when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
    self._heartbeat_when = when
    if self._heartbeat_cb is None:
        # We do not cancel the previous heartbeat_cb here because
        # it generates a significant amount of TimerHandle churn
        # which causes asyncio to rebuild the heap frequently.
        # Instead _send_heartbeat() will reschedule the next
        # heartbeat if it fires too early.
        self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)

def _send_heartbeat(self) -> None:
if self._heartbeat is not None and not self._closed:
assert self._loop is not None and self._writer is not None
# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
self._loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = call_later(
self._pong_not_received,
self._pong_heartbeat,
self._loop,
timeout_ceil_threshold=(
self._req._protocol._timeout_ceil_threshold
if self._req is not None
else 5
),
self._heartbeat_cb = None
loop = self._loop
assert loop is not None and self._writer is not None
now = loop.time()
if now < self._heartbeat_when:
# Heartbeat fired too early, reschedule
self._heartbeat_cb = loop.call_at(
self._heartbeat_when, self._send_heartbeat
)
return

# fire-and-forget a task is not perfect but maybe ok for
# sending ping. Otherwise we need a long-living heartbeat
# task in the class.
loop.create_task(self._writer.ping()) # type: ignore[unused-awaitable]

req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
)
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)

def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._closed = True
self._set_closed()
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = asyncio.TimeoutError()

def _set_closed(self) -> None:
"""Set the connection to closed.
Cancel any heartbeat timers and set the closed flag.
"""
self._closed = True
self._cancel_heartbeat()

async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
# make pre-check to don't hide it by do_handshake() exceptions
if self._payload_writer is not None:
Expand Down Expand Up @@ -410,7 +431,7 @@ async def close(
if self._closed:
return False

self._closed = True
self._set_closed()
try:
await self._writer.close(code, message)
writer = self._payload_writer
Expand Down Expand Up @@ -454,6 +475,7 @@ def _set_closing(self, code: WSCloseCode) -> None:
"""Set the close code and mark the connection as closing."""
self._closing = True
self._close_code = code
self._cancel_heartbeat()

def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
Expand Down Expand Up @@ -566,5 +588,6 @@ def _cancel(self, exc: BaseException) -> None:
# web_protocol calls this from connection_lost
# or when the server is shutting down.
self._closing = True
self._cancel_heartbeat()
if self._reader is not None:
set_exception(self._reader, exc)
Loading

0 comments on commit c4acabc

Please sign in to comment.