feat: set prefetch in client for traffic control #897

Merged
merged 8 commits on Mar 7, 2023
76 changes: 41 additions & 35 deletions client/clip_client/client.py
@@ -153,11 +153,8 @@ def _prepare_streaming(self, disable, total):
os.environ['JINA_GRPC_SEND_BYTES'] = '0'
os.environ['JINA_GRPC_RECV_BYTES'] = '0'

self._s_task = self._pbar.add_task(
':arrow_up: Send', total=total, total_size=0, start=False
)
self._r_task = self._pbar.add_task(
':arrow_down: Recv', total=total, total_size=0, start=False
':arrow_down: Progress', total=total, total_size=0, start=False
)

@staticmethod
@@ -171,12 +168,8 @@ def _gather_result(
def _iter_doc(
self, content, results: Optional['DocumentArray'] = None
) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, str):
_mime = mimetypes.guess_type(c)[0]
@@ -199,17 +192,6 @@
else:
raise TypeError(f'unsupported input type {c!r}')

if hasattr(self, '_pbar'):
self._pbar.update(
self._s_task,
advance=1,
total_size=str(
filesize.decimal(
int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
)
),
)

if results is not None:
results.append(d)
yield d
@@ -251,6 +233,7 @@ def encode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'np.ndarray':
"""Encode images and texts into embeddings where the input is an iterable of raw strings.
Each image and text must be represented as a string. The following strings are acceptable:
@@ -268,6 +251,8 @@
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after completion of each request.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` equals the length of ``content``
"""
...
@@ -283,6 +268,7 @@ def encode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.
:param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
@@ -295,6 +281,8 @@
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after completion of each request.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` equals the length of ``content``
"""
...
@@ -314,6 +302,7 @@ def encode(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
@@ -334,6 +323,7 @@ def encode(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
Contributor comment: You can just inline the code like `prefetch=kwargs.get('prefetch', 1000)`. You can reduce the number of lines if you do this for the other params as well.
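The inlining the reviewer suggests can be sketched as follows (a hypothetical, stripped-down `post`-style call; names and defaults are illustrative, not the actual clip-client internals):

```python
def post(**params):
    # Stand-in for the underlying client.post() call: just echo params.
    return params

def encode(content, **kwargs):
    # Inline the defaults at the call site instead of popping each
    # keyword into a local variable first.
    return post(
        content=content,
        batch_size=kwargs.get('batch_size', 8),
        prefetch=kwargs.get('prefetch', 100),
    )

call = encode(['hello'], prefetch=10)
```

The trade-off is fewer lines versus losing a named local that can be inspected or reused before the call.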

)

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
@@ -350,6 +340,7 @@ async def aencode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'np.ndarray':
...

@@ -364,6 +355,7 @@ async def aencode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
...

@@ -382,6 +374,7 @@ async def aencode(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
@@ -402,6 +395,7 @@ async def aencode(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

@@ -411,29 +405,13 @@ async def aencode(self, content, **kwargs):
def _iter_rank_docs(
self, content, results: Optional['DocumentArray'] = None, source='matches'
) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, Document):
d = self._prepare_rank_doc(c, source)
else:
raise TypeError(f'Unsupported input type {c!r}')

if hasattr(self, '_pbar'):
self._pbar.update(
self._s_task,
advance=1,
total_size=str(
filesize.decimal(
int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
)
),
)

if results is not None:
results.append(d)
yield d
@@ -498,6 +476,7 @@ def rank(
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
@@ -516,6 +495,7 @@ def rank(
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
@@ -533,6 +513,7 @@ async def arank(
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
@@ -551,6 +532,7 @@ async def arank(
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

@@ -567,6 +549,7 @@ def index(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
"""Index the images or texts where their embeddings are computed by the server CLIP model.

@@ -585,6 +568,8 @@
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` equals the length of ``content``
"""
...
@@ -600,6 +585,7 @@ def index(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Index the images or texts where their embeddings are computed by the server CLIP model.

@@ -613,6 +599,8 @@
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` equals the length of ``content``
"""
...
@@ -630,6 +618,7 @@ def index(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
@@ -649,6 +638,7 @@ def index(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
@@ -664,6 +654,7 @@ async def aindex(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...

@@ -678,6 +669,7 @@ async def aindex(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...

@@ -694,6 +686,7 @@ async def aindex(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
@@ -713,6 +706,7 @@ async def aindex(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

@@ -730,6 +724,7 @@ def search(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Search for top k results for given query string or ``Document``.

@@ -747,6 +742,8 @@
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
"""
...

@@ -762,6 +759,7 @@ def search(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Search for top k results for given query string or ``Document``.

@@ -779,6 +777,8 @@
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
"""
...

@@ -795,6 +795,7 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
@@ -813,6 +814,7 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
@@ -829,6 +831,7 @@ async def asearch(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...

@@ -844,6 +847,7 @@ async def asearch(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...

@@ -860,6 +864,7 @@ async def asearch(self, content, limit: int = 10, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
@@ -878,6 +883,7 @@ async def asearch(self, content, limit: int = 10, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

13 changes: 11 additions & 2 deletions docs/user-guides/client.md
@@ -426,6 +426,14 @@ You can specify `.encode(..., batch_size=8)` to control how many `Document`s are

Intuitively, setting `batch_size=1024` should result in very high GPU utilization on each request. However, a large batch size like this also means sending each request would take longer. Given that `clip-client` is designed with request and response streaming, large batch size would not benefit from the time overlapping between request streaming and response streaming.
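As an illustration of what `batch_size` controls, here is a minimal, self-contained sketch (plain Python, not the clip-client internals) of grouping an iterable into fixed-size request batches:

```python
from itertools import islice

def batched(iterable, batch_size=8):
    """Yield lists of at most `batch_size` items from any iterable."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield chunk

# 20 inputs with batch_size=8 become 3 requests: 8 + 8 + 4 items.
batches = list(batched(range(20), batch_size=8))
```

A smaller `batch_size` means more, cheaper requests that overlap better with response streaming; a larger one means fewer requests, each taking longer to assemble and send.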

### Control prefetch size

To control the number of in-flight batches, use the `.encode(..., prefetch=100)` option.
Because of the client's asynchronous design, the outgoing request stream of a large job usually finishes before the incoming response stream: request handling on the server is typically time-consuming, so responses lag behind, and the server may even close the connection because it considers the incoming channel idle.
The default prefetch value is 100. A lower value is recommended for expensive operations, and a higher value when you want faster response times.

For more information about client prefetching, please refer to [Rate Limit](https://docs.jina.ai/concepts/client/rate-limit/) in Jina documentation.
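The effect of `prefetch` can be pictured as a semaphore bounding how many batches are in flight at once; a minimal sketch under that analogy (illustrative only, not the actual client implementation):

```python
import asyncio

async def send_batches(batches, prefetch=3):
    """Process `batches` while allowing at most `prefetch` in flight."""
    sem = asyncio.Semaphore(prefetch)
    in_flight = 0
    peak = 0  # highest number of simultaneously in-flight batches seen

    async def handle(batch):
        nonlocal in_flight, peak
        async with sem:  # blocks once `prefetch` batches are in flight
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0)  # stand-in for server-side work
            in_flight -= 1
            return batch

    results = await asyncio.gather(*(handle(b) for b in batches))
    return results, peak

results, peak = asyncio.run(send_batches(list(range(10)), prefetch=3))
```

However many batches are submitted, at most `prefetch` of them are awaiting a response at any moment, which is what keeps a slow server from being flooded.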

### Show progressbar

@@ -459,8 +467,9 @@ Here are some suggestions when encoding a large number of `Document`s:

c.encode(iglob('**/*.png'))
```
2. Adjust `batch_size`.
3. Turn on the progressbar.
2. Adjust the `batch_size` parameter.
3. Adjust the `prefetch` parameter.
4. Turn on the progressbar.

````{danger}
In any case, avoiding the following coding:
Binary file modified docs/user-guides/images/client-pgbar.gif