feat: set prefetch in client for traffic control #897

Merged — 8 commits merged on Mar 7, 2023
Changes from 4 commits
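For context, a minimal usage sketch of the new parameter (an illustration, not part of the diff; the server address is a placeholder and the default of 1000 comes from the signatures below):

from clip_client import Client

# placeholder address; point this at your own CLIP server deployment
client = Client('grpc://0.0.0.0:51000')

# a lower prefetch caps the number of in-flight requests, which helps when
# the server-side model is expensive; a higher value favors throughput
embeddings = client.encode(
    ['First do it', 'then do it right', 'then do it better'],
    prefetch=100,
)
print(embeddings.shape)  # (3, D) numpy array, per the encode() docstring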
94 changes: 67 additions & 27 deletions client/clip_client/client.py
@@ -153,9 +153,9 @@ def _prepare_streaming(self, disable, total):
os.environ['JINA_GRPC_SEND_BYTES'] = '0'
os.environ['JINA_GRPC_RECV_BYTES'] = '0'

self._s_task = self._pbar.add_task(
':arrow_up: Send', total=total, total_size=0, start=False
)
# self._s_task = self._pbar.add_task(
# ':arrow_up: Send', total=total, total_size=0, start=False
# )
self._r_task = self._pbar.add_task(
':arrow_down: Recv', total=total, total_size=0, start=False
)
@@ -174,8 +174,8 @@ def _iter_doc(
from rich import filesize
from docarray import Document

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)
# if hasattr(self, '_pbar'):
# self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, str):
@@ -199,16 +199,16 @@ def _iter_doc(
else:
raise TypeError(f'unsupported input type {c!r}')

if hasattr(self, '_pbar'):
self._pbar.update(
self._s_task,
advance=1,
total_size=str(
filesize.decimal(
int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
)
),
)
# if hasattr(self, '_pbar'):
# self._pbar.update(
# self._s_task,
# advance=1,
# total_size=str(
# filesize.decimal(
# int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
# )
# ),
# )

if results is not None:
results.append(d)
@@ -251,6 +251,7 @@ def encode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
) -> 'np.ndarray':
"""Encode images and texts into embeddings where the input is an iterable of raw strings.
Each image and text must be represented as a string. The following strings are acceptable:
@@ -268,6 +269,8 @@ def encode(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after completion of each request.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight requests made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
@@ -283,6 +286,7 @@ def encode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
) -> 'DocumentArray':
"""Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.
:param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
@@ -295,6 +299,8 @@ def encode(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after completion of each request.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight requests made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
@@ -314,6 +320,7 @@ def encode(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 1000)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
@@ -334,6 +341,7 @@ def encode(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
Reviewer comment (Contributor):
You can just inline the code, e.g. prefetch=kwargs.get('prefetch', 1000). You can reduce the number of lines if you do this for the other params as well. (See the sketch after this hunk.)

)

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
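The reviewer's suggestion above, as a minimal self-contained sketch (hypothetical names, not code from this diff): read optional parameters inline with kwargs.get() instead of pop()-ing each one into a local variable first.

def post(content, **kwargs):
    # hypothetical stand-in for self._client.post(...)
    print('prefetch =', kwargs.get('prefetch'))

def encode(content, **kwargs):
    # before: prefetch = kwargs.pop('prefetch', 1000); post(content, prefetch=prefetch)
    # after: pass it through inline, saving one line per optional parameter
    post(content, prefetch=kwargs.get('prefetch', 1000))

encode(['hello'], prefetch=100)  # prints: prefetch = 100
encode(['hello'])                # prints: prefetch = 1000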
@@ -350,6 +358,7 @@ async def aencode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
) -> 'np.ndarray':
...

@@ -364,6 +373,7 @@ async def aencode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
) -> 'DocumentArray':
...

@@ -382,6 +392,7 @@ async def aencode(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 1000)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
@@ -402,6 +413,7 @@ async def aencode(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue
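The async variant accepts the same parameter; a hedged usage sketch (placeholder address):

import asyncio
from clip_client import Client

async def main():
    client = Client('grpc://0.0.0.0:51000')  # placeholder address
    # same prefetch semantics as encode(): fewer in-flight requests for an
    # expensive server-side model, more when you want higher throughput
    embeddings = await client.aencode(['hello', 'world'], prefetch=100)
    print(embeddings.shape)

asyncio.run(main())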

@@ -414,25 +426,25 @@ def _iter_rank_docs(
from rich import filesize
from docarray import Document

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)
# if hasattr(self, '_pbar'):
# self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, Document):
d = self._prepare_rank_doc(c, source)
else:
raise TypeError(f'Unsupported input type {c!r}')

if hasattr(self, '_pbar'):
self._pbar.update(
self._s_task,
advance=1,
total_size=str(
filesize.decimal(
int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
)
),
)
# if hasattr(self, '_pbar'):
# self._pbar.update(
# self._s_task,
# advance=1,
# total_size=str(
# filesize.decimal(
# int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
# )
# ),
# )

if results is not None:
results.append(d)
@@ -498,6 +510,7 @@ def rank(
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 1000)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
@@ -516,6 +529,7 @@ def rank(
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
@@ -533,6 +547,7 @@ async def arank(
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 1000)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
@@ -551,6 +566,7 @@ async def arank(
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

@@ -567,6 +583,7 @@ def index(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
):
"""Index the images or texts where their embeddings are computed by the server CLIP model.

@@ -585,6 +602,8 @@ def index(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight requests made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
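A minimal sketch of passing prefetch to index() (illustrative only; the address is a placeholder and indexing assumes a server flow with an indexer behind it):

from docarray import Document, DocumentArray
from clip_client import Client

client = Client('grpc://0.0.0.0:51000')  # placeholder address

docs = DocumentArray(
    [
        Document(text='a happy dog'),
        Document(uri='https://example.com/cat.jpg'),  # hypothetical image URI
    ]
)

# throttle to 50 in-flight requests while the server computes embeddings
client.index(docs, prefetch=50)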
@@ -600,6 +619,7 @@ def index(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
) -> 'DocumentArray':
"""Index the images or texts where their embeddings are computed by the server CLIP model.

@@ -613,6 +633,8 @@ def index(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight requests made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
@@ -630,6 +652,7 @@ def index(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 1000)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
@@ -649,6 +672,7 @@ def index(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
@@ -664,6 +688,7 @@ async def aindex(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
):
...

@@ -678,6 +703,7 @@ async def aindex(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
):
...

@@ -694,6 +720,7 @@ async def aindex(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 1000)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
@@ -713,6 +740,7 @@ async def aindex(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

@@ -730,6 +758,7 @@ def search(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
) -> 'DocumentArray':
"""Search for top k results for given query string or ``Document``.

@@ -747,6 +776,8 @@ def search(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight requests made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
"""
...
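A sketch of search() with the new parameter (illustrative values; assumes documents were indexed beforehand against the placeholder address):

from clip_client import Client

client = Client('grpc://0.0.0.0:51000')  # placeholder address

# top 5 matches per query, with at most 200 requests in flight
results = client.search(['a photo of a cat'], limit=5, prefetch=200)
for match in results[0].matches:
    print(match.uri, match.scores)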

@@ -762,6 +793,7 @@ def search(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
) -> 'DocumentArray':
"""Search for top k results for given query string or ``Document``.

@@ -779,6 +811,8 @@ def search(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight requests made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
"""
...

@@ -795,6 +829,7 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 1000)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
@@ -813,6 +848,7 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
@@ -829,6 +865,7 @@ async def asearch(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
):
...

@@ -844,6 +881,7 @@ async def asearch(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 1000,
):
...

@@ -860,6 +898,7 @@ async def asearch(self, content, limit: int = 10, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 1000)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
@@ -878,6 +917,7 @@ async def asearch(self, content, limit: int = 10, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue
