-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
feat: add traversal paths #750
Changes from 12 commits
8082eec
1e3c2aa
2933e2e
07186dd
780bb36
e50b477
dff3e7b
d232b8e
0877128
6561225
49eee18
726c661
5a7f205
2331d91
ada03bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -132,7 +132,7 @@ def _gather_result(self, r): | |||||
|
||||||
@property | ||||||
def _unboxed_result(self): | ||||||
if self._results.embeddings is None: | ||||||
if self._return_plain and self._results.embeddings is None: | ||||||
raise ValueError( | ||||||
'empty embedding returned from the server. ' | ||||||
'This often due to a mis-config of the server, ' | ||||||
|
@@ -158,14 +158,11 @@ def _iter_doc(self, content) -> Generator['Document', None, None]: | |||||
else: | ||||||
yield Document(text=c) | ||||||
elif isinstance(c, Document): | ||||||
if c.content_type in ('text', 'blob'): | ||||||
self._return_plain = False | ||||||
self._return_plain = False | ||||||
if c.content_type in ('text', 'blob', 'tensor'): | ||||||
yield c | ||||||
elif not c.blob and c.uri: | ||||||
c.load_uri_to_blob() | ||||||
self._return_plain = False | ||||||
yield c | ||||||
elif c.tensor is not None: | ||||||
yield c | ||||||
else: | ||||||
raise TypeError(f'unsupported input type {c!r} {c.content_type}') | ||||||
|
@@ -184,10 +181,15 @@ def _iter_doc(self, content) -> Generator['Document', None, None]: | |||||
) | ||||||
|
||||||
def _get_post_payload(self, content, kwargs): | ||||||
parameters = {} | ||||||
if 'batch_size' in kwargs: | ||||||
parameters['minibatch_size'] = kwargs['batch_size'] | ||||||
|
||||||
return dict( | ||||||
on='/', | ||||||
inputs=self._iter_doc(content), | ||||||
request_size=kwargs.get('batch_size', 8), | ||||||
request_size=kwargs.get('batch_size', 32), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
parameters=parameters, | ||||||
numb3r3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
total_docs=len(content) if hasattr(content, '__len__') else None, | ||||||
) | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -22,9 +22,9 @@ def __init__( | |||||
use_default_preprocessing: bool = True, | ||||||
max_length: int = 77, | ||||||
device: str = 'cpu', | ||||||
overwrite_embeddings: bool = False, | ||||||
num_worker_preprocess: int = 4, | ||||||
minibatch_size: int = 32, | ||||||
traversal_paths: str = '@r', | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
*args, | ||||||
**kwargs, | ||||||
): | ||||||
|
@@ -52,12 +52,15 @@ def __init__( | |||||
:param num_worker_preprocess: Number of cpu processes used in preprocessing step. | ||||||
:param minibatch_size: Default batch size for encoding, used if the | ||||||
batch size is not passed as a parameter with the request. | ||||||
:param traversal_paths: Default traversal paths for encoding, used if | ||||||
the traversal path is not passed as a parameter with the request. | ||||||
Comment on lines
+55
to
+56
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
super().__init__(*args, **kwargs) | ||||||
self._minibatch_size = minibatch_size | ||||||
|
||||||
self._use_default_preprocessing = use_default_preprocessing | ||||||
self._max_length = max_length | ||||||
self._traversal_paths = traversal_paths | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
# self.device = device | ||||||
if not device: | ||||||
|
@@ -110,32 +113,36 @@ def _preproc_images(self, docs: 'DocumentArray'): | |||||
name='preprocess_images_seconds', | ||||||
documentation='images preprocess time in seconds', | ||||||
): | ||||||
tensors_batch = [] | ||||||
if self._use_default_preprocessing: | ||||||
tensors_batch = [] | ||||||
|
||||||
for d in docs: | ||||||
content = d.content | ||||||
for d in docs: | ||||||
content = d.content | ||||||
|
||||||
if d.blob: | ||||||
d.convert_blob_to_image_tensor() | ||||||
elif d.uri: | ||||||
d.load_uri_to_image_tensor() | ||||||
if d.blob: | ||||||
d.convert_blob_to_image_tensor() | ||||||
elif d.tensor is None and d.uri: | ||||||
# in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri | ||||||
d.load_uri_to_image_tensor() | ||||||
|
||||||
tensors_batch.append(d.tensor) | ||||||
tensors_batch.append(d.tensor) | ||||||
|
||||||
# recover content | ||||||
d.content = content | ||||||
# recover content | ||||||
d.content = content | ||||||
|
||||||
if self._use_default_preprocessing: | ||||||
batch_data = self._vision_preprocessor( | ||||||
images=tensors_batch, | ||||||
return_tensors='pt', | ||||||
) | ||||||
batch_data = {k: v.to(self._device) for k, v in batch_data.items()} | ||||||
batch_data = { | ||||||
k: v.type(torch.float32).to(self._device) | ||||||
for k, v in batch_data.items() | ||||||
} | ||||||
|
||||||
else: | ||||||
batch_data = { | ||||||
'pixel_values': torch.tensor( | ||||||
tensors_batch, dtype=torch.float32, device=self._device | ||||||
docs.tensors, dtype=torch.float32, device=self._device | ||||||
) | ||||||
} | ||||||
|
||||||
|
@@ -163,7 +170,7 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): | |||||
set_rank(docs) | ||||||
|
||||||
@requests | ||||||
async def encode(self, docs: DocumentArray, **kwargs): | ||||||
async def encode(self, docs: DocumentArray, parameters: Dict = {}, **kwargs): | ||||||
""" | ||||||
Encode all documents with `text` or image content using the corresponding CLIP | ||||||
encoder. Store the embeddings in the `embedding` attribute. | ||||||
|
@@ -181,18 +188,25 @@ async def encode(self, docs: DocumentArray, **kwargs): | |||||
the CLIP model was trained on images of the size ``224 x 224``, and that | ||||||
they are of the shape ``[3, H, W]`` with ``dtype=float32``. They should | ||||||
also be normalized (values between 0 and 1). | ||||||
:param parameters: A dictionary that contains parameters to control encoding. | ||||||
The accepted keys are ``traversal_paths`` and ``minibatch_size`` - in their | ||||||
absence their corresponding default values are used. | ||||||
""" | ||||||
|
||||||
traversal_paths = parameters.get('traversal_paths', self._traversal_paths) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @hanxiao I don't agree with this suggestion. It will break the following use case:
It is impossible to pass the proper parameters:
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. By defining the default traversal path in
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @hanxiao From the above flow example, there are two encoders, both working on documents at different levels (one on the root level, another on the chunk level).
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. But you should be able to send a parameter to one particular Executor.
||||||
minibatch_size = parameters.get('minibatch_size', self._minibatch_size) | ||||||
|
||||||
_img_da = DocumentArray() | ||||||
_txt_da = DocumentArray() | ||||||
for d in docs: | ||||||
for d in docs[traversal_paths]: | ||||||
split_img_txt_da(d, _img_da, _txt_da) | ||||||
|
||||||
with torch.inference_mode(): | ||||||
# for image | ||||||
if _img_da: | ||||||
for minibatch, batch_data in _img_da.map_batch( | ||||||
self._preproc_images, | ||||||
batch_size=self._minibatch_size, | ||||||
batch_size=minibatch_size, | ||||||
pool=self._pool, | ||||||
): | ||||||
with self.monitor( | ||||||
|
@@ -210,7 +224,7 @@ async def encode(self, docs: DocumentArray, **kwargs): | |||||
if _txt_da: | ||||||
for minibatch, batch_data in _txt_da.map_batch( | ||||||
self._preproc_texts, | ||||||
batch_size=self._minibatch_size, | ||||||
batch_size=minibatch_size, | ||||||
pool=self._pool, | ||||||
): | ||||||
with self.monitor( | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I disagree with exposing `minibatch_size` to the public client. It can easily overload a CAS server. Imagine that a user now has the capability of controlling both `request_size` and `minibatch_size`: the user can easily occupy the full GPU capacity on our Berlin GPU server, and can easily make our GPU go OOM by setting a large `request_size` and `minibatch_size`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In a client-server architecture, one should not aim to expose every server argument to the client; it is very risky.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see — that makes sense. Then we need to update the documentation about how to control the batch size.