feat: add traversal paths #750
@@ -22,9 +22,9 @@ def __init__(
         use_default_preprocessing: bool = True,
         max_length: int = 77,
         device: str = 'cpu',
-        overwrite_embeddings: bool = False,
         num_worker_preprocess: int = 4,
         minibatch_size: int = 32,
+        traversal_paths: str = '@r',
         *args,
         **kwargs,
     ):
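For readers new to DocArray's traversal-path syntax: `'@r'` addresses the root documents themselves, `'@c'` their chunks, and paths can be combined. A minimal sketch (assuming the DocArray v1-style `DocumentArray` used by this repo at the time):

```python
from docarray import Document, DocumentArray

# One root document carrying two chunks (e.g. an image plus its caption).
da = DocumentArray(
    [Document(text='root', chunks=[Document(text='one'), Document(text='two')])]
)

print(len(da['@r']))    # 1 -- the root document itself (the new default)
print(len(da['@c']))    # 2 -- the flattened chunk level
print(len(da['@r,c']))  # 3 -- roots and chunks together
```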
@@ -52,12 +52,15 @@ def __init__(
         :param num_worker_preprocess: Number of cpu processes used in preprocessing step.
         :param minibatch_size: Default batch size for encoding, used if the
             batch size is not passed as a parameter with the request.
+        :param traversal_paths: Default traversal paths for encoding, used if
+            the traversal path is not passed as a parameter with the request.
         """
         super().__init__(*args, **kwargs)
         self._minibatch_size = minibatch_size

         self._use_default_preprocessing = use_default_preprocessing
         self._max_length = max_length
+        self._traversal_paths = traversal_paths

         # self.device = device
         if not device:
@@ -110,32 +113,36 @@ def _preproc_images(self, docs: 'DocumentArray'):
             name='preprocess_images_seconds',
             documentation='images preprocess time in seconds',
         ):
-            tensors_batch = []
-
-            for d in docs:
-                content = d.content
-
-                if d.blob:
-                    d.convert_blob_to_image_tensor()
-                elif d.uri:
-                    d.load_uri_to_image_tensor()
-
-                tensors_batch.append(d.tensor)
-
-                # recover content
-                d.content = content
-
             if self._use_default_preprocessing:
+                tensors_batch = []
+
+                for d in docs:
+                    content = d.content
+
+                    if d.blob:
+                        d.convert_blob_to_image_tensor()
+                    elif d.tensor is None and d.uri:
+                        # in case the user uses the HTTP protocol and sends data via curl in .uri instead of .blob (base64)
+                        d.load_uri_to_image_tensor()
+
+                    tensors_batch.append(d.tensor)
+
+                    # recover content
+                    d.content = content
+
                 batch_data = self._vision_preprocessor(
                     images=tensors_batch,
                     return_tensors='pt',
                 )
-                batch_data = {k: v.to(self._device) for k, v in batch_data.items()}
+                batch_data = {
+                    k: v.type(torch.float32).to(self._device)
+                    for k, v in batch_data.items()
+                }
             else:
                 batch_data = {
                     'pixel_values': torch.tensor(
-                        tensors_batch, dtype=torch.float32, device=self._device
+                        docs.tensors, dtype=torch.float32, device=self._device
                     )
                 }
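Note the restructure of the `else` branch: when `use_default_preprocessing` is `False`, the per-document conversion loop is now skipped entirely and the batch is built from `docs.tensors` rather than the no-longer-populated `tensors_batch`. A sketch of what callers are then expected to supply (shapes and values per the docstring in the hunk below; names and data are illustrative):

```python
import numpy as np
from docarray import Document, DocumentArray

# With use_default_preprocessing=False, each document must already carry a
# preprocessed image tensor: float32, shape [3, H, W], values in [0, 1].
docs = DocumentArray(
    [Document(tensor=np.random.rand(3, 224, 224).astype(np.float32)) for _ in range(4)]
)

# DocumentArray.tensors stacks the per-document tensors into one ndarray,
# here [4, 3, 224, 224] -- the value the `else` branch passes to torch.tensor.
print(docs.tensors.shape)
```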
@@ -163,7 +170,7 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
         set_rank(docs)

     @requests
-    async def encode(self, docs: DocumentArray, **kwargs):
+    async def encode(self, docs: DocumentArray, parameters: Dict = {}, **kwargs):
         """
         Encode all documents with `text` or image content using the corresponding CLIP
         encoder. Store the embeddings in the `embedding` attribute.
@@ -181,18 +188,25 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
             the CLIP model was trained on images of the size ``224 x 224``, and that
             they are of the shape ``[3, H, W]`` with ``dtype=float32``. They should
             also be normalized (values between 0 and 1).
+        :param parameters: A dictionary that contains parameters to control encoding.
+            The accepted keys are ``traversal_paths`` and ``minibatch_size`` - in their
+            absence their corresponding default values are used.
         """
+        traversal_paths = parameters.get('traversal_paths', self._traversal_paths)
A reviewer suggested a change on the `traversal_paths = parameters.get(...)` line (the suggestion body is not preserved in this view); the ensuing thread:

> @hanxiao I don't agree with this suggestion. It will break the following use case: [flow example not shown] It is impossible to pass the proper parameters: [example not shown]

> By defining the default traversal path in [comment truncated]

> why is it impossible? I don't get it

> @hanxiao From the above flow example, there are two encoders, each working on documents at a different level (one on the root level, the other on the chunk level). [comment truncated]

> but you should be able to send parameters to one particular Executor
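The use case under debate, sketched in Python with hypothetical names (the Flow layout and the import path are assumptions, not taken from the PR): the same encoder mounted twice, each instance given its own default via `uses_with`, so neither depends on request parameters to pick its level.

```python
from jina import Flow

# Import path assumed from the repo layout; adjust if it differs.
from clip_server.executors.clip_torch import CLIPEncoder

# Two instances of the same encoder, each with its own default traversal path:
# one embeds root-level documents, the other their chunks. A single request-level
# `parameters={'traversal_paths': ...}` would reach both executors, which is the
# conflict the author describes; targeting parameters at one named Executor is
# the workaround hanxiao points to.
f = (
    Flow()
    .add(name='root_encoder', uses=CLIPEncoder, uses_with={'traversal_paths': '@r'})
    .add(name='chunk_encoder', uses=CLIPEncoder, uses_with={'traversal_paths': '@c'})
)
```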
The hunk continues below the review thread:

+        minibatch_size = parameters.get('minibatch_size', self._minibatch_size)
+
         _img_da = DocumentArray()
         _txt_da = DocumentArray()
-        for d in docs:
+        for d in docs[traversal_paths]:
             split_img_txt_da(d, _img_da, _txt_da)

         with torch.inference_mode():
             # for image
             if _img_da:
                 for minibatch, batch_data in _img_da.map_batch(
                     self._preproc_images,
-                    batch_size=self._minibatch_size,
+                    batch_size=minibatch_size,
                     pool=self._pool,
                 ):
                     with self.monitor(
@@ -210,7 +224,7 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
             if _txt_da:
                 for minibatch, batch_data in _txt_da.map_batch(
                     self._preproc_texts,
-                    batch_size=self._minibatch_size,
+                    batch_size=minibatch_size,
                     pool=self._pool,
                 ):
                     with self.monitor(
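With the full change in place, both defaults remain overridable per request. A hedged usage sketch (port, endpoint, and data are illustrative; `Client.post` with a `parameters` dict is standard Jina API):

```python
from docarray import Document, DocumentArray
from jina import Client

client = Client(port=12345)  # port is illustrative

# Override the executor's defaults for this request only:
# encode the chunk level, in minibatches of 8.
client.post(
    on='/',
    inputs=DocumentArray([Document(uri='photo.jpg')]),
    parameters={'traversal_paths': '@c', 'minibatch_size': 8},
)
```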