Skip to content

Commit

Permalink
[serve] Prevent ASGI receive() from hanging on disconnects (ray-pro…
Browse files Browse the repository at this point in the history
…ject#44647)

See ray-project#44644 for the user-reported issue.

In ray-project#43713, I changed the behavior of the ASGI receive interface to not raise an exception but rather silently hang if the request was not present in the proxy. This was done under the assumption that the code calling receive would be subsequently cancelled anyways. However, this assumption was faulty for a few reasons:

- The user may ignore the CancelledError but still want to fetch ASGI messages.
- Cancellation is best-effort. Specifically, because we are currently using the run_coroutine_threadsafe / concurrent.futures.Executor interface, there is no way to guarantee that either the scheduling task is cancelled or the resulting object ref is cancelled. In rare cases, neither may happen (see asyncio.run_coroutine_threadsafe leaves underlying cancelled asyncio task running python/cpython#105836).

This PR mitigates the issue by returning an appropriate disconnect ASGI message when the proxy no longer has the request in scope.

Note that this is not a perfect solution: we may still drop the messages prior to the disconnect (e.g., the body of the HTTP request) if the replica does not fetch them prior to the client disconnecting.
- Note that this is not a regression from the current behavior and should only happen in practice if the client disconnects very quickly after initiating the request.
- As a follow-up, we should rework this codepath to something more robust. For example we could use a push-based model where the proxy instead pushes messages eagerly to the replica (this was originally very difficult to implement but should be possible now).

---------

Signed-off-by: Edward Oakes <[email protected]>
  • Loading branch information
edoakes authored Apr 11, 2024
1 parent 7364705 commit 905ca61
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
25 changes: 24 additions & 1 deletion python/ray/serve/_private/http_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,34 @@ class ASGIReceiveProxy:

def __init__(
self,
scope: Scope,
request_id: str,
receive_asgi_messages: Callable[[str], Awaitable[bytes]],
):
self._type = scope["type"] # Either 'http' or 'websocket'.
self._queue = asyncio.Queue()
self._request_id = request_id
self._receive_asgi_messages = receive_asgi_messages
self._disconnect_message = None

def _get_default_disconnect_message(self) -> Message:
"""Return the appropriate disconnect message based on the connection type.
HTTP ASGI spec:
https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event
WS ASGI spec:
https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event-ws
"""
if self._type == "websocket":
return {
"type": "websocket.disconnect",
# 1005 is the default disconnect code according to the ASGI spec.
"code": 1005,
}
else:
return {"type": "http.disconnect"}

async def fetch_until_disconnect(self):
"""Fetch messages repeatedly until a disconnect message is received.
Expand All @@ -244,8 +264,11 @@ async def fetch_until_disconnect(self):
return
except KeyError:
# KeyError can be raised if the request is no longer active in the proxy
# (e.g., the user disconnects). This is expected behavior and we should
# (i.e., the user disconnects). This is expected behavior and we should
# not log an error: https://github.com/ray-project/ray/issues/43290.
message = self._get_default_disconnect_message()
self._queue.put_nowait(message)
self._disconnect_message = message
return
except Exception as e:
# Raise unexpected exceptions in the next `__call__`.
Expand Down
6 changes: 5 additions & 1 deletion python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ def _wrap_user_method_call(self, request_metadata: RequestMetadata):
try:
self._metrics_manager.inc_num_ongoing_requests()
yield
except asyncio.CancelledError as e:
user_exception = e
except Exception as e:
user_exception = e
logger.error(f"Request failed:\n{e}")
Expand Down Expand Up @@ -942,15 +944,17 @@ def _prepare_args_for_http_request(
The returned `receive_task` should be cancelled when the user method exits.
"""
scope = pickle.loads(request.pickled_asgi_scope)
receive = ASGIReceiveProxy(
scope,
request_metadata.request_id,
request.receive_asgi_messages,
)
receive_task = self._user_code_event_loop.create_task(
receive.fetch_until_disconnect()
)
asgi_args = ASGIArgs(
scope=pickle.loads(request.pickled_asgi_scope),
scope=scope,
receive=receive,
send=generator_result_callback,
)
Expand Down
43 changes: 31 additions & 12 deletions python/ray/serve/tests/unit/test_http_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import pickle
import sys
from typing import Generator, Tuple

import pytest
Expand Down Expand Up @@ -65,9 +66,12 @@ async def test_message_queue():

@pytest.fixture
@pytest.mark.asyncio
def setup_receive_proxy() -> Generator[
Tuple[ASGIReceiveProxy, MessageQueue], None, None
]:
def setup_receive_proxy(
request,
) -> Generator[Tuple[ASGIReceiveProxy, MessageQueue], None, None]:
# Param can be 'http' (default) or 'websocket' (ASGI scope type).
type = getattr(request, "param", "http")

queue = MessageQueue()

async def receive_asgi_messages(request_id: str) -> bytes:
Expand All @@ -80,7 +84,7 @@ async def receive_asgi_messages(request_id: str) -> bytes:
return pickle.dumps(messages)

loop = get_or_create_event_loop()
asgi_receive_proxy = ASGIReceiveProxy("", receive_asgi_messages)
asgi_receive_proxy = ASGIReceiveProxy({"type": type}, "", receive_asgi_messages)
receiver_task = loop.create_task(asgi_receive_proxy.fetch_until_disconnect())
try:
yield asgi_receive_proxy, queue
Expand Down Expand Up @@ -125,9 +129,20 @@ async def test_raises_exception(
with pytest.raises(RuntimeError, match="oopsies"):
await asgi_receive_proxy()

async def test_does_not_raise_key_error(
@pytest.mark.parametrize(
"setup_receive_proxy",
["http", "websocket"],
indirect=True,
)
async def test_return_disconnect_on_key_error(
self, setup_receive_proxy: Tuple[ASGIReceiveProxy, MessageQueue]
):
"""If the proxy is no longer handling a given request, it raises a KeyError.
In these cases, the ASGI receive proxy should return a disconnect message.
See https://github.com/ray-project/ray/pull/44647 for details.
"""
asgi_receive_proxy, queue = setup_receive_proxy

queue.put_nowait({"type": "foo"})
Expand All @@ -136,17 +151,23 @@ async def test_does_not_raise_key_error(
assert await asgi_receive_proxy() == {"type": "bar"}

queue.put_nowait(KeyError("not found"))
_, pending = await asyncio.wait(
[asyncio.create_task(asgi_receive_proxy())], timeout=0.01
)
assert len(pending) == 1
for _ in range(100):
if asgi_receive_proxy._type == "http":
assert await asgi_receive_proxy() == {"type": "http.disconnect"}
else:
assert await asgi_receive_proxy() == {
"type": "websocket.disconnect",
"code": 1005,
}

async def test_receive_asgi_messages_raises(self):
async def receive_asgi_messages(request_id: str) -> bytes:
raise RuntimeError("maybe actor crashed")

loop = get_or_create_event_loop()
asgi_receive_proxy = ASGIReceiveProxy("", receive_asgi_messages)
asgi_receive_proxy = ASGIReceiveProxy(
{"type": "http"}, "", receive_asgi_messages
)
receiver_task = loop.create_task(asgi_receive_proxy.fetch_until_disconnect())

try:
Expand All @@ -157,6 +178,4 @@ async def receive_asgi_messages(request_id: str) -> bytes:


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", "-s", __file__]))

0 comments on commit 905ca61

Please sign in to comment.