Skip to content

Commit

Permalink
[Core] Stop iteratoring cancelled grpc request streams (ray-project#2…
Browse files Browse the repository at this point in the history
…3865)

Signed-off-by: Rueian <[email protected]>
  • Loading branch information
rueian committed Oct 12, 2022
1 parent 223ce39 commit 39cf221
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 1 deletion.
77 changes: 76 additions & 1 deletion python/ray/tests/test_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import sys
import time
from glob import glob
from unittest.mock import patch
from unittest.mock import patch, MagicMock
from itertools import chain

import grpc
import pytest
Expand Down Expand Up @@ -398,6 +399,80 @@ def make_internal_kv_calls():
make_internal_kv_calls()


@pytest.mark.skipif(
sys.platform == "win32", reason="PSUtil does not work the same on windows."
)
def test_proxy_cancelled_grpc_request_stream():
"""
Test that DataServicerProxy and LogstreamServicerProxy should gracefully
close grpc stream when the request stream is cancelled.
"""

proxier.CHECK_PROCESS_INTERVAL_S = 1
# The timeout has likely been set to 1 in an earlier test. Increase timeout
# to wait for the channel to become ready.
proxier.CHECK_CHANNEL_TIMEOUT_S = 5
os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5"
pm, free_ports = start_ray_and_proxy_manager(n_ports=2)

data_servicer = proxier.DataServicerProxy(pm)
logstream_servicer = proxier.LogstreamServicerProxy(pm)

# simulate cancelled grpc request stream
# https://github.com/grpc/grpc/blob/v1.43.0/src/python/grpcio/grpc/_server.py#L353-L354
class Cancelled:
def __iter__(self):
return self

def __next__(self):
raise grpc.RpcError()

context = MagicMock()
context.set_code = MagicMock()
context.set_details = MagicMock()
context.invocation_metadata = MagicMock(
return_value=[
("client_id", "client1"),
("reconnecting", "False"),
]
)

init = ray_client_pb2.DataRequest(
req_id=1,
init=ray_client_pb2.InitRequest(job_config=pickle.dumps(JobConfig())),
)

for _ in data_servicer.Datapath(chain([init], Cancelled()), context):
pass
for _ in logstream_servicer.Logstream(Cancelled(), context):
pass

assert not context.set_code.called, "grpc error should not be set"
assert not context.set_details.called, "grpc error should not be set"

class Rendezvous:
def __iter__(self):
return self

def __next__(self):
raise grpc._Rendezvous()

context.invocation_metadata = MagicMock(
return_value=[
("client_id", "client2"),
("reconnecting", "False"),
]
)

for _ in data_servicer.Datapath(chain([init], Rendezvous()), context):
pass
for _ in logstream_servicer.Logstream(Rendezvous(), context):
pass

assert context.set_code.called, "grpc error should be set"
assert context.set_details.called, "grpc error should be set"


if __name__ == "__main__":
import sys

Expand Down
26 changes: 26 additions & 0 deletions python/ray/util/client/server/proxier.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,30 @@ def prepare_runtime_init_req(
return (init_request, new_job_config)


class RequestIteratorProxy:
def __init__(self, request_iterator):
self.request_iterator = request_iterator

def __iter__(self):
return self

def __next__(self):
try:
return next(self.request_iterator)
except grpc.RpcError as e:
# To stop proxying already CANCLLED request stream gracefully,
# we only translate the exact grpc.RpcError to StopIteration,
# not its subsclasses. ex: grpc._Rendezvous
# https://github.com/grpc/grpc/blob/v1.43.0/src/python/grpcio/grpc/_server.py#L353-L354
# This fixes the https://github.com/ray-project/ray/issues/23865
if type(e) != grpc.RpcError:
raise e # re-raise other grpc exceptions
logger.exception(
"Stop iterating cancelled request stream with the following exception:"
)
raise StopIteration


class DataServicerProxy(ray_client_pb2_grpc.RayletDataStreamerServicer):
def __init__(self, proxy_manager: ProxyManager):
self.num_clients = 0
Expand Down Expand Up @@ -635,6 +659,7 @@ def modify_connection_info_resp(
return modified_resp

def Datapath(self, request_iterator, context):
request_iterator = RequestIteratorProxy(request_iterator)
cleanup_requested = False
start_time = time.time()
client_id = _get_client_id_from_context(context)
Expand Down Expand Up @@ -761,6 +786,7 @@ def __init__(self, proxy_manager: ProxyManager):
self.proxy_manager = proxy_manager

def Logstream(self, request_iterator, context):
request_iterator = RequestIteratorProxy(request_iterator)
client_id = _get_client_id_from_context(context)
if client_id == "":
return
Expand Down

0 comments on commit 39cf221

Please sign in to comment.