Skip to content

Commit

Permalink
test: add extra tests for dynamic batching
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Dec 14, 2023
1 parent 94fff08 commit d5813b6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 5 deletions.
9 changes: 4 additions & 5 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@ def _cancel_timer_if_pending(self):
def _start_timer(self):
self._cancel_timer_if_pending()
self._timer_task = asyncio.create_task(
self._sleep_then_set(self._flush_trigger)
self._sleep_then_set()
)
self._timer_started = True

async def _sleep_then_set(self, event: Event):
async def _sleep_then_set(self):
"""Sleep and then set the event
:param event: event to set
"""
await asyncio.sleep(self._timeout / 1000)
event.set()
self._flush_trigger.set()

Check warning on line 85 in jina/serve/runtimes/worker/batch_queue.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/batch_queue.py#L85

Added line #L85 was not covered by tests

async def push(self, request: DataRequest) -> asyncio.Queue:
"""Append request to the the list of requests to be processed.
Expand Down Expand Up @@ -220,7 +220,6 @@ def batch(iterable_1, iterable_2, n=1):
# communicate that the request has been processed properly. At this stage the data_lock is ours and
# therefore noone can add requests to this list.
self._flush_trigger: Event = Event()
self._cancel_timer_if_pending()
self._timer_task = None
try:
if not docarray_v2:
Expand Down Expand Up @@ -274,7 +273,7 @@ def batch(iterable_1, iterable_2, n=1):
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(batch_res_docs)
non_assigned_to_response_docs.extend(batch_res_docs or docs_inner_batch)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(

Check warning on line 278 in jina/serve/runtimes/worker/batch_queue.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/batch_queue.py#L276-L278

Added lines #L276 - L278 were not covered by tests
non_assigned_to_response_docs,
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/deployments/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,14 @@ def serve_depl(stop_event, **kwargs):
t.join()


@pytest.mark.repeat(10)
@pytest.mark.parametrize('served_depl', [False, True], indirect=True)
def test_deployment_dynamic_batching(served_depl, exposed_port):
docs = Client(port=exposed_port).post(on='/bar', inputs=DocumentArray.empty(5))
assert docs.texts == ['bar' for _ in docs]


@pytest.mark.repeat(10)
@pytest.mark.parametrize('enable_dynamic_batching', [False, True])
def test_deployment_client_dynamic_batching(enable_dynamic_batching):
kwargs = {'port': random_port()}
Expand Down
30 changes: 30 additions & 0 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
DocumentArray,
Executor,
Flow,
Deployment,
dynamic_batching,
requests,
)
from jina.clients.request import request_generator
from jina.serve.networking.utils import send_request_sync
from jina.serve.runtimes.servers import BaseServer
from jina_cli.api import executor_native
from jina.proto import jina_pb2
from tests.helper import _generate_pod_args

cur_dir = os.path.dirname(__file__)
Expand Down Expand Up @@ -311,6 +313,7 @@ def test_preferred_batch_size(add_parameters, use_stream):
assert time_taken < TIMEOUT_TOLERANCE


@pytest.mark.repeat(10)
@pytest.mark.parametrize('use_stream', [False, True])
def test_correctness(use_stream):
f = Flow().add(uses=PlaceholderExecutor)
Expand Down Expand Up @@ -492,6 +495,7 @@ def test_param_correctness(use_stream):
]
assert [doc.text for doc in results[2]] == [f'D{str(PARAM1)}']


@pytest.mark.parametrize(
'uses',
[
Expand Down Expand Up @@ -622,3 +626,29 @@ def test_failure_propagation():
'/wronglennone',
inputs=DocumentArray([Document(text=str(i)) for i in range(8)]),
)


@pytest.mark.repeat(10)
def test_exception_handling_in_dynamic_batch():
class SlowExecutorWithException(Executor):

@dynamic_batching(preferred_batch_size=3, timeout=1000)
@requests(on='/foo')
def foo(self, docs, **kwargs):
for doc in docs:
if doc.text == 'fail':
raise Exception('Fail is in the Batch')

depl = Deployment(uses=SlowExecutorWithException)

with depl:
da = DocumentArray([Document(text='good') for _ in range(50)])
da[4].text = 'fail'
responses = depl.post(on='/foo', inputs=da, request_size=1, return_responses=True, continue_on_error=True, results_in_order=True)
assert len(responses) == 50 # 1 request per input
num_failed_requests = 0
for r in responses:
if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR:
num_failed_requests += 1

assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing

0 comments on commit d5813b6

Please sign in to comment.