diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index d09cb4bdcdc7..f23646263964 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -4,7 +4,8 @@ import sys import time from glob import glob -from unittest.mock import patch +from unittest.mock import patch, MagicMock +from itertools import chain import grpc import pytest @@ -398,6 +399,80 @@ def make_internal_kv_calls(): make_internal_kv_calls() +@pytest.mark.skipif( + sys.platform == "win32", reason="PSUtil does not work the same on windows." +) +def test_proxy_cancelled_grpc_request_stream(): + """ + Test that DataServicerProxy and LogstreamServicerProxy should gracefully + close grpc stream when the request stream is cancelled. + """ + + proxier.CHECK_PROCESS_INTERVAL_S = 1 + # The timeout has likely been set to 1 in an earlier test. Increase timeout + # to wait for the channel to become ready. + proxier.CHECK_CHANNEL_TIMEOUT_S = 5 + os.environ["TIMEOUT_FOR_SPECIFIC_SERVER_S"] = "5" + pm, free_ports = start_ray_and_proxy_manager(n_ports=2) + + data_servicer = proxier.DataServicerProxy(pm) + logstream_servicer = proxier.LogstreamServicerProxy(pm) + + # simulate cancelled grpc request stream + # https://github.com/grpc/grpc/blob/v1.43.0/src/python/grpcio/grpc/_server.py#L353-L354 + class Cancelled: + def __iter__(self): + return self + + def __next__(self): + raise grpc.RpcError() + + context = MagicMock() + context.set_code = MagicMock() + context.set_details = MagicMock() + context.invocation_metadata = MagicMock( + return_value=[ + ("client_id", "client1"), + ("reconnecting", "False"), + ] + ) + + init = ray_client_pb2.DataRequest( + req_id=1, + init=ray_client_pb2.InitRequest(job_config=pickle.dumps(JobConfig())), + ) + + for _ in data_servicer.Datapath(chain([init], Cancelled()), context): + pass + for _ in logstream_servicer.Logstream(Cancelled(), context): + pass + + assert not context.set_code.called, "grpc error 
class RequestIteratorProxy:
    """Iterator wrapper that ends a cancelled gRPC request stream gracefully.

    Items are forwarded unchanged from the wrapped iterator. When the wrapped
    iterator raises an exception whose type is exactly ``grpc.RpcError``, the
    error is logged and iteration stops via ``StopIteration`` instead of the
    error propagating to the servicer. Subclasses of ``grpc.RpcError`` are
    re-raised untouched.

    Args:
        request_iterator: The iterator to wrap; ``next()`` is called on it
            directly, so it must already be an iterator.
    """

    def __init__(self, request_iterator):
        self.request_iterator = request_iterator

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next request, or stop if the stream was cancelled.

        Raises:
            StopIteration: When the wrapped iterator is exhausted, or when it
                raises an exception of exactly type ``grpc.RpcError``.
            grpc.RpcError: Any subclass instance raised by the wrapped
                iterator is propagated unchanged.
        """
        try:
            return next(self.request_iterator)
        except grpc.RpcError as e:
            # To stop proxying an already CANCELLED request stream gracefully,
            # we only translate the exact grpc.RpcError to StopIteration,
            # not its subclasses. ex: grpc._Rendezvous
            # https://github.com/grpc/grpc/blob/v1.43.0/src/python/grpcio/grpc/_server.py#L353-L354
            # This fixes https://github.com/ray-project/ray/issues/23865
            if type(e) is not grpc.RpcError:
                # Deliberately an exact-type test rather than isinstance():
                # subclasses signal other failures and must reach the caller.
                raise  # re-raise other grpc exceptions
            logger.exception(
                "Stop iterating cancelled request stream with the following exception:"
            )
            raise StopIteration