diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py
index 4f8cc387a9150..f0a25f4202d8f 100644
--- a/jina/serve/runtimes/worker/batch_queue.py
+++ b/jina/serve/runtimes/worker/batch_queue.py
@@ -72,17 +72,15 @@ def _cancel_timer_if_pending(self):
     def _start_timer(self):
         self._cancel_timer_if_pending()
         self._timer_task = asyncio.create_task(
-            self._sleep_then_set(self._flush_trigger)
+            self._sleep_then_set()
         )
         self._timer_started = True
 
-    async def _sleep_then_set(self, event: Event):
+    async def _sleep_then_set(self):
         """Sleep and then set the event
-
-        :param event: event to set
         """
         await asyncio.sleep(self._timeout / 1000)
-        event.set()
+        self._flush_trigger.set()
 
     async def push(self, request: DataRequest) -> asyncio.Queue:
         """Append request to the the list of requests to be processed.
@@ -220,7 +218,6 @@ def batch(iterable_1, iterable_2, n=1):
            # communicate that the request has been processed properly. At this stage the data_lock is ours and
            # therefore noone can add requests to this list.
            self._flush_trigger: Event = Event()
-           self._cancel_timer_if_pending()
            self._timer_task = None
            try:
                if not docarray_v2:
@@ -274,7 +271,7 @@ def batch(iterable_1, iterable_2, n=1):
                                await request_full.put(exc)
                        else:
                            # We need to attribute the docs to their requests
-                            non_assigned_to_response_docs.extend(batch_res_docs)
+                            non_assigned_to_response_docs.extend(batch_res_docs or docs_inner_batch)
                            non_assigned_to_response_request_idxs.extend(req_idxs)
                            num_assigned_docs = await _assign_results(
                                non_assigned_to_response_docs,
diff --git a/tests/integration/deployments/test_deployment.py b/tests/integration/deployments/test_deployment.py
index e4a1c5185faba..f28ef122223e6 100644
--- a/tests/integration/deployments/test_deployment.py
+++ b/tests/integration/deployments/test_deployment.py
@@ -458,12 +458,14 @@ def serve_depl(stop_event, **kwargs):
     t.join()
 
 
+@pytest.mark.repeat(10)
 @pytest.mark.parametrize('served_depl', [False, True], indirect=True)
 def test_deployment_dynamic_batching(served_depl, exposed_port):
     docs = Client(port=exposed_port).post(on='/bar', inputs=DocumentArray.empty(5))
     assert docs.texts == ['bar' for _ in docs]
 
 
+@pytest.mark.repeat(10)
 @pytest.mark.parametrize('enable_dynamic_batching', [False, True])
 def test_deployment_client_dynamic_batching(enable_dynamic_batching):
     kwargs = {'port': random_port()}
diff --git a/tests/integration/docarray_v2/test_v2.py b/tests/integration/docarray_v2/test_v2.py
index 273a60eb8f10b..da82bb52479e4 100644
--- a/tests/integration/docarray_v2/test_v2.py
+++ b/tests/integration/docarray_v2/test_v2.py
@@ -811,7 +811,6 @@ def foo(self, **kwargs):
 
 def test_custom_gateway():
     from docarray import DocList
-    from docarray.documents.text import TextDoc
 
     from jina import Executor, Flow, requests
     from jina.serve.runtimes.gateway.http import FastAPIBaseGateway
@@ -1612,3 +1611,36 @@ def generate(
         return_type=DocList[MyRandomModel],
     )
     assert res[0].a == 'hey'
+
+
+@pytest.mark.repeat(10)
+def test_exception_handling_in_dynamic_batch():
+    from jina.proto import jina_pb2
+
+    class DummyEmbeddingDoc(BaseDoc):
+        lf: List[float] = []
+
+    class SlowExecutorWithException(Executor):
+        @dynamic_batching(preferred_batch_size=3, timeout=1000)
+        @requests(on='/foo')
+        def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[DummyEmbeddingDoc]:
+            ret = DocList[DummyEmbeddingDoc]()
+            for doc in docs:
+                if doc.text == 'fail':
+                    raise Exception('Fail is in the Batch')
+                ret.append(DummyEmbeddingDoc(lf=[0.1, 0.2, 0.3]))
+            return ret
+
+    depl = Deployment(uses=SlowExecutorWithException)
+
+    with depl:
+        da = DocList[TextDoc]([TextDoc(text='good') for _ in range(50)])
+        da[4].text = 'fail'
+        responses = depl.post(on='/foo', inputs=da, request_size=1, return_responses=True, continue_on_error=True, results_in_order=True)
+        assert len(responses) == 50  # 1 request per input
+        num_failed_requests = 0
+        for r in responses:
+            if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR:
+                num_failed_requests += 1
+
+        assert 1 <= num_failed_requests <= 3  # at most the 3 requests batched together can fail
\ No newline at end of file
diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py
index 3e174162c2894..b21769e09fddd 100644
--- a/tests/integration/dynamic_batching/test_dynamic_batching.py
+++ b/tests/integration/dynamic_batching/test_dynamic_batching.py
@@ -14,6 +14,7 @@
     DocumentArray,
     Executor,
     Flow,
+    Deployment,
     dynamic_batching,
     requests,
 )
@@ -21,6 +22,7 @@
 from jina.serve.networking.utils import send_request_sync
 from jina.serve.runtimes.servers import BaseServer
 from jina_cli.api import executor_native
+from jina.proto import jina_pb2
 from tests.helper import _generate_pod_args
 
 cur_dir = os.path.dirname(__file__)
@@ -311,6 +313,7 @@ def test_preferred_batch_size(add_parameters, use_stream):
     assert time_taken < TIMEOUT_TOLERANCE
 
 
+@pytest.mark.repeat(10)
 @pytest.mark.parametrize('use_stream', [False, True])
 def test_correctness(use_stream):
     f = Flow().add(uses=PlaceholderExecutor)
@@ -492,6 +495,7 @@ def test_param_correctness(use_stream):
     ]
     assert [doc.text for doc in results[2]] == [f'D{str(PARAM1)}']
 
+
 @pytest.mark.parametrize(
     'uses',
     [
@@ -622,3 +626,29 @@ def test_failure_propagation():
            '/wronglennone',
            inputs=DocumentArray([Document(text=str(i)) for i in range(8)]),
        )
+
+
+@pytest.mark.repeat(10)
+def test_exception_handling_in_dynamic_batch():
+    class SlowExecutorWithException(Executor):
+        @dynamic_batching(preferred_batch_size=3, timeout=1000)
+        @requests(on='/foo')
+        def foo(self, docs, **kwargs):
+            for doc in docs:
+                if doc.text == 'fail':
+                    raise Exception('Fail is in the Batch')
+
+    depl = Deployment(uses=SlowExecutorWithException)
+
+    with depl:
+        da = DocumentArray([Document(text='good') for _ in range(50)])
+        da[4].text = 'fail'
+        responses = depl.post(on='/foo', inputs=da, request_size=1, return_responses=True, continue_on_error=True, results_in_order=True)
+        assert len(responses) == 50  # 1 request per input
+        num_failed_requests = 0
+        for r in responses:
+            if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR:
+                num_failed_requests += 1
+
+        assert 1 <= num_failed_requests <= 3  # at most the 3 requests batched together can fail