Skip to content

Commit

Permalink
fix: do not assign results if Exception (#6128)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Dec 14, 2023
1 parent 129b7b3 commit 48e4437
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 33 deletions.
54 changes: 22 additions & 32 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,15 @@ def _cancel_timer_if_pending(self):
def _start_timer(self):
self._cancel_timer_if_pending()
self._timer_task = asyncio.create_task(
self._sleep_then_set(self._flush_trigger)
self._sleep_then_set()
)
self._timer_started = True

async def _sleep_then_set(self, event: Event):
async def _sleep_then_set(self):
"""Sleep and then set the event
:param event: event to set
"""
await asyncio.sleep(self._timeout / 1000)
event.set()
self._flush_trigger.set()

async def push(self, request: DataRequest) -> asyncio.Queue:
"""Append request to the the list of requests to be processed.
Expand Down Expand Up @@ -220,13 +218,13 @@ def batch(iterable_1, iterable_2, n=1):
# communicate that the request has been processed properly. At this stage the data_lock is ours and
# therefore no one can add requests to this list.
self._flush_trigger: Event = Event()
self._cancel_timer_if_pending()
self._timer_task = None
try:
if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()

non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
Expand Down Expand Up @@ -271,33 +269,25 @@ def batch(iterable_1, iterable_2, n=1):
involved_requests_min_indx : involved_requests_max_indx + 1
]:
await request_full.put(exc)
pass

# If there has been an exception, this will be docs_inner_batch
output_executor_docs = (
batch_res_docs
if batch_res_docs is not None
else docs_inner_batch
)

# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(output_executor_docs)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(batch_res_docs or docs_inner_batch)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
non_assigned_to_response_docs,
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/deployments/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,14 @@ def serve_depl(stop_event, **kwargs):
t.join()


@pytest.mark.repeat(10)
@pytest.mark.parametrize('served_depl', [False, True], indirect=True)
def test_deployment_dynamic_batching(served_depl, exposed_port):
    """Post five empty Documents to the served Deployment's `/bar` endpoint
    and check every returned Document carries the text 'bar'.

    Repeated 10x to shake out flakiness in the dynamic-batching flush path.
    """
    # `served_depl` (parametrized fixture) keeps the server alive for this call.
    client = Client(port=exposed_port)
    result = client.post(on='/bar', inputs=DocumentArray.empty(5))
    assert result.texts == ['bar'] * len(result)


@pytest.mark.repeat(10)
@pytest.mark.parametrize('enable_dynamic_batching', [False, True])
def test_deployment_client_dynamic_batching(enable_dynamic_batching):
kwargs = {'port': random_port()}
Expand Down
34 changes: 33 additions & 1 deletion tests/integration/docarray_v2/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,6 @@ def foo(self, **kwargs):

def test_custom_gateway():
from docarray import DocList
from docarray.documents.text import TextDoc

from jina import Executor, Flow, requests
from jina.serve.runtimes.gateway.http import FastAPIBaseGateway
Expand Down Expand Up @@ -1612,3 +1611,36 @@ def generate(
return_type=DocList[MyRandomModel],
)
assert res[0].a == 'hey'


@pytest.mark.repeat(10)
def test_exception_handling_in_dynamic_batch():
    """An exception raised inside a dynamic batch must fail only the requests
    whose docs were part of the failing batch (#6128): with
    preferred_batch_size=3 and one poison doc among 50 single-doc requests,
    between 1 and 3 requests may error while the rest succeed.
    """
    from jina.proto import jina_pb2

    class DummyEmbeddingDoc(BaseDoc):
        # Fake embedding payload returned for every successfully processed doc.
        lf: List[float] = []

    class SlowExecutorWithException(Executor):

        @dynamic_batching(preferred_batch_size=3, timeout=1000)
        @requests(on='/foo')
        def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[DummyEmbeddingDoc]:
            # Abort the whole batch as soon as the poison doc is encountered.
            out = DocList[DummyEmbeddingDoc]()
            for doc in docs:
                if doc.text == 'fail':
                    raise Exception('Fail is in the Batch')
                out.append(DummyEmbeddingDoc(lf=[0.1, 0.2, 0.3]))
            return out

    deployment = Deployment(uses=SlowExecutorWithException)

    with deployment:
        inputs = DocList[TextDoc]([TextDoc(text='good') for _ in range(50)])
        inputs[4].text = 'fail'
        responses = deployment.post(
            on='/foo',
            inputs=inputs,
            request_size=1,
            return_responses=True,
            continue_on_error=True,
            results_in_order=True,
        )
        # One request per input doc.
        assert len(responses) == 50
        error_code = jina_pb2.StatusProto.StatusCode.ERROR
        num_failed_requests = sum(
            1 for r in responses if r.header.status.code == error_code
        )

        # The poison doc shares its batch with at most two other requests.
        assert 1 <= num_failed_requests <= 3
30 changes: 30 additions & 0 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
DocumentArray,
Executor,
Flow,
Deployment,
dynamic_batching,
requests,
)
from jina.clients.request import request_generator
from jina.serve.networking.utils import send_request_sync
from jina.serve.runtimes.servers import BaseServer
from jina_cli.api import executor_native
from jina.proto import jina_pb2
from tests.helper import _generate_pod_args

cur_dir = os.path.dirname(__file__)
Expand Down Expand Up @@ -311,6 +313,7 @@ def test_preferred_batch_size(add_parameters, use_stream):
assert time_taken < TIMEOUT_TOLERANCE


@pytest.mark.repeat(10)
@pytest.mark.parametrize('use_stream', [False, True])
def test_correctness(use_stream):
f = Flow().add(uses=PlaceholderExecutor)
Expand Down Expand Up @@ -492,6 +495,7 @@ def test_param_correctness(use_stream):
]
assert [doc.text for doc in results[2]] == [f'D{str(PARAM1)}']


@pytest.mark.parametrize(
'uses',
[
Expand Down Expand Up @@ -622,3 +626,29 @@ def test_failure_propagation():
'/wronglennone',
inputs=DocumentArray([Document(text=str(i)) for i in range(8)]),
)


@pytest.mark.repeat(10)
def test_exception_handling_in_dynamic_batch():
    """An exception raised while processing a dynamic batch must only fail the
    requests that contributed docs to that batch: with preferred_batch_size=3
    and a single poison doc among 50 single-doc requests, between 1 and 3
    requests may error while all others succeed.
    """

    class SlowExecutorWithException(Executor):

        @dynamic_batching(preferred_batch_size=3, timeout=1000)
        @requests(on='/foo')
        def foo(self, docs, **kwargs):
            # Blow up the batch as soon as the poison doc is encountered.
            for doc in docs:
                if doc.text == 'fail':
                    raise Exception('Fail is in the Batch')

    deployment = Deployment(uses=SlowExecutorWithException)

    with deployment:
        inputs = DocumentArray([Document(text='good') for _ in range(50)])
        inputs[4].text = 'fail'
        responses = deployment.post(
            on='/foo',
            inputs=inputs,
            request_size=1,
            return_responses=True,
            continue_on_error=True,
            results_in_order=True,
        )
        # One request per input doc.
        assert len(responses) == 50
        error_code = jina_pb2.StatusProto.StatusCode.ERROR
        num_failed_requests = sum(
            1 for r in responses if r.header.status.code == error_code
        )

        # The poison doc shares its batch with at most two other requests.
        assert 1 <= num_failed_requests <= 3

0 comments on commit 48e4437

Please sign in to comment.