Commit
feat: add traversal paths (#750)
* feat: add traversal paths

* fix: collate batch

* fix: parameters dict

* fix: rank parameters

* fix: change default minibatch size

* fix: support traversal in client

* fix: pass minibatch_size from client

* fix: unittest

* fix: error

* fix: unittest

* fix: tensorrt traversal paths

* fix: revert client

* fix: minor revision

* fix: client

* refactor: parameter rename and set default batch_size

Co-authored-by: Han Xiao <[email protected]>
numb3r3 and hanxiao authored Jun 13, 2022
1 parent d5be8c2 commit e022bd4
Showing 9 changed files with 140 additions and 132 deletions.
11 changes: 5 additions & 6 deletions client/clip_client/client.py
@@ -158,14 +158,11 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
                 else:
                     yield Document(text=c)
             elif isinstance(c, Document):
-                if c.content_type in ('text', 'blob'):
-                    self._return_plain = False
+                self._return_plain = False
+                if c.content_type in ('text', 'blob', 'tensor'):
                     yield c
                 elif not c.blob and c.uri:
                     c.load_uri_to_blob()
-                    self._return_plain = False
                     yield c
-                elif c.tensor is not None:
-                    yield c
                 else:
                     raise TypeError(f'unsupported input type {c!r} {c.content_type}')
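
With the change above, a Document whose content is already a tensor flows through the same accepting branch as text and blob inputs. A hypothetical client-side input mix (file names are placeholders, not from this diff):

    import numpy as np
    from docarray import Document

    # each of these is now accepted by _iter_doc without a TypeError
    inputs = [
        'a sentence to embed',                                     # plain str
        Document(text='hello world'),                              # text content
        Document(tensor=np.zeros((224, 224, 3), dtype=np.uint8)),  # pre-loaded tensor, newly accepted
        Document(uri='photo.jpg'),                                 # loaded to blob lazily
    ]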
@@ -187,7 +184,9 @@ def _get_post_payload(self, content, kwargs):
         return dict(
             on='/',
             inputs=self._iter_doc(content),
-            request_size=kwargs.get('batch_size', 8),
+            request_size=kwargs.get(
+                'batch_size', 8
+            ),  # the default `batch_size` is very subjective. i would set it 8 based on 2 considerations (1) play safe on most GPUs (2) ease the load to our demo server
             total_docs=len(content) if hasattr(content, '__len__') else None,
         )

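For reference, the `batch_size` keyword passed to `Client.encode(...)` is what `_get_post_payload` above forwards as the Jina `request_size`. A minimal usage sketch (the server address is a placeholder):

    from clip_client import Client

    c = Client('grpc://0.0.0.0:51000')  # placeholder address

    # at most 8 docs travel per request unless overridden, matching the new default
    embeddings = c.encode(
        ['hello world', 'apple.png'],
        batch_size=16,  # forwarded as request_size by _get_post_payload
    )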
60 changes: 37 additions & 23 deletions server/clip_server/executors/clip_hg.py
@@ -19,12 +19,12 @@ def __init__(
         finetuned_checkpoint_path: Optional[str] = None,
         base_feature_extractor: Optional[str] = None,
         base_tokenizer_model: Optional[str] = None,
-        use_default_preprocessing: bool = True,
+        preprocessing: bool = True,
         max_length: int = 77,
         device: str = 'cpu',
-        overwrite_embeddings: bool = False,
         num_worker_preprocess: int = 4,
         minibatch_size: int = 32,
+        traversal_paths: str = '@r',
         *args,
         **kwargs,
     ):
@@ -41,7 +41,7 @@ def __init__(
             Defaults to ``pretrained_model_name_or_path`` if None.
         :param base_tokenizer_model: Base tokenizer model.
             Defaults to ``pretrained_model_name_or_path`` if None.
-        :param use_default_preprocessing: Whether to use the `base_feature_extractor`
+        :param preprocessing: Whether to use the `base_feature_extractor`
             on images (tensors) before encoding them. If you disable this, you must
             ensure that the images you pass in have the correct format, see the
             ``encode`` method for details.
@@ -52,12 +52,15 @@ def __init__(
         :param num_worker_preprocess: Number of cpu processes used in preprocessing step.
         :param minibatch_size: Default batch size for encoding, used if the
             batch size is not passed as a parameter with the request.
+        :param traversal_paths: Default traversal paths for encoding, used if
+            the traversal path is not passed as a parameter with the request.
         """
         super().__init__(*args, **kwargs)
         self._minibatch_size = minibatch_size
 
-        self._use_default_preprocessing = use_default_preprocessing
+        self._preprocessing = preprocessing
         self._max_length = max_length
+        self._traversal_paths = traversal_paths
 
         # self.device = device
         if not device:
@@ -110,32 +113,36 @@ def _preproc_images(self, docs: 'DocumentArray'):
             name='preprocess_images_seconds',
             documentation='images preprocess time in seconds',
         ):
-            tensors_batch = []
+            if self._preprocessing:
+                tensors_batch = []
 
-            for d in docs:
-                content = d.content
+                for d in docs:
+                    content = d.content
 
-                if d.blob:
-                    d.convert_blob_to_image_tensor()
-                elif d.uri:
-                    d.load_uri_to_image_tensor()
+                    if d.blob:
+                        d.convert_blob_to_image_tensor()
+                    elif d.tensor is None and d.uri:
+                        # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri
+                        d.load_uri_to_image_tensor()
 
-                tensors_batch.append(d.tensor)
+                    tensors_batch.append(d.tensor)
 
-                # recover content
-                d.content = content
+                    # recover content
+                    d.content = content
 
-            if self._use_default_preprocessing:
                 batch_data = self._vision_preprocessor(
                     images=tensors_batch,
                     return_tensors='pt',
                 )
-                batch_data = {k: v.to(self._device) for k, v in batch_data.items()}
+                batch_data = {
+                    k: v.type(torch.float32).to(self._device)
+                    for k, v in batch_data.items()
+                }
+
             else:
                 batch_data = {
                     'pixel_values': torch.tensor(
-                        tensors_batch, dtype=torch.float32, device=self._device
+                        docs.tensors, dtype=torch.float32, device=self._device
                     )
                 }
@@ -163,7 +170,7 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
         set_rank(docs)
 
     @requests
-    async def encode(self, docs: DocumentArray, **kwargs):
+    async def encode(self, docs: DocumentArray, parameters: Dict = {}, **kwargs):
         """
         Encode all documents with `text` or image content using the corresponding CLIP
         encoder. Store the embeddings in the `embedding` attribute.
@@ -172,27 +179,34 @@ async def encode(self, docs: DocumentArray, **kwargs):
             ``tensor`` of the
             shape ``Height x Width x 3``. By default, the input ``tensor`` must
             be an ``ndarray`` with ``dtype=uint8`` or ``dtype=float32``.
-            If you set ``use_default_preprocessing=True`` when creating this encoder,
+            If you set ``preprocessing=True`` when creating this encoder,
             then the ``tensor`` arrays should have the shape ``[H, W, 3]``, and be in
             the RGB color format with ``dtype=uint8``.
-            If you set ``use_default_preprocessing=False`` when creating this encoder,
+            If you set ``preprocessing=False`` when creating this encoder,
             then you need to ensure that the images you pass in are already
             pre-processed. This means that they are all the same size (for batching) -
             the CLIP model was trained on images of the size ``224 x 224``, and that
             they are of the shape ``[3, H, W]`` with ``dtype=float32``. They should
             also be normalized (values between 0 and 1).
+        :param parameters: A dictionary that contains parameters to control encoding.
+            The accepted keys are ``traversal_paths`` and ``minibatch_size`` - in their
+            absence their corresponding default values are used.
         """
 
+        traversal_paths = parameters.get('traversal_paths', self._traversal_paths)
+        minibatch_size = parameters.get('minibatch_size', self._minibatch_size)
+
         _img_da = DocumentArray()
         _txt_da = DocumentArray()
-        for d in docs:
+        for d in docs[traversal_paths]:
             split_img_txt_da(d, _img_da, _txt_da)
 
         with torch.inference_mode():
             # for image
             if _img_da:
                 for minibatch, batch_data in _img_da.map_batch(
                     self._preproc_images,
-                    batch_size=self._minibatch_size,
+                    batch_size=minibatch_size,
                     pool=self._pool,
                 ):
                     with self.monitor(
@@ -210,7 +224,7 @@ async def encode(self, docs: DocumentArray, **kwargs):
             if _txt_da:
                 for minibatch, batch_data in _txt_da.map_batch(
                     self._preproc_texts,
-                    batch_size=self._minibatch_size,
+                    batch_size=minibatch_size,
                     pool=self._pool,
                 ):
                     with self.monitor(
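To exercise the new per-request parameters end to end, something like the following should work against a running server (the gateway address and file name are assumptions, not part of this diff):

    from jina import Client
    from docarray import Document, DocumentArray

    client = Client(host='grpc://0.0.0.0:51000')  # placeholder gateway

    docs = DocumentArray(
        [
            Document(
                chunks=[
                    Document(text='a photo of a dog'),
                    Document(uri='dog.jpg'),
                ]
            )
        ]
    )

    # '@c' makes the executor encode the chunks instead of the root document;
    # minibatch_size overrides the constructor default for this request only
    result = client.post(
        on='/',
        inputs=docs,
        parameters={'traversal_paths': '@c', 'minibatch_size': 64},
    )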
51 changes: 19 additions & 32 deletions server/clip_server/executors/clip_onnx.py
@@ -21,16 +21,18 @@ def __init__(
         name: str = 'ViT-B/32',
         device: Optional[str] = None,
         num_worker_preprocess: int = 4,
-        minibatch_size: int = 16,
+        minibatch_size: int = 32,
+        traversal_paths: str = '@r',
         **kwargs,
     ):
         super().__init__(**kwargs)
 
+        self._minibatch_size = minibatch_size
+        self._traversal_paths = traversal_paths
+
         self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
         self._pool = ThreadPool(processes=num_worker_preprocess)
 
-        self._minibatch_size = minibatch_size
-
         self._model = CLIPOnnxModel(name)
 
         import torch
@@ -59,7 +61,7 @@ def __init__(
             and hasattr(self.runtime_args, 'replicas')
         ):
             replicas = getattr(self.runtime_args, 'replicas', 1)
-            num_threads = max(1, torch.get_num_threads() // replicas)
+            num_threads = max(1, torch.get_num_threads() * 2 // replicas)
             if num_threads < 2:
                 warnings.warn(
                     f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in '
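
One consequence of the doubled numerator above: with, say, 16 torch intra-op threads split across 4 replicas, each replica is now budgeted max(1, 16 * 2 // 4) = 8 threads instead of 4, so the "too few threads" warning fires less often.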
@@ -98,55 +100,40 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
         set_rank(docs)
 
     @requests
-    async def encode(self, docs: 'DocumentArray', **kwargs):
+    async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
+
+        traversal_paths = parameters.get('traversal_paths', self._traversal_paths)
+        minibatch_size = parameters.get('minibatch_size', self._minibatch_size)
+
         _img_da = DocumentArray()
         _txt_da = DocumentArray()
-        for d in docs:
+        for d in docs[traversal_paths]:
             split_img_txt_da(d, _img_da, _txt_da)
 
         # for image
         if _img_da:
-            for minibatch, _contents in _img_da.map_batch(
+            for minibatch, batch_data in _img_da.map_batch(
                 self._preproc_images,
-                batch_size=self._minibatch_size,
+                batch_size=minibatch_size,
                 pool=self._pool,
             ):
                 with self.monitor(
                     name='encode_images_seconds',
                     documentation='images encode time in seconds',
                 ):
-                    minibatch.embeddings = self._model.encode_image(minibatch.tensors)
-
-                    # recover original content
-                    try:
-                        _ = iter(_contents)
-                        for _d, _ct in zip(minibatch, _contents):
-                            _d.content = _ct
-                    except TypeError:
-                        pass
+                    minibatch.embeddings = self._model.encode_image(batch_data)
 
-            # for text
+        # for text
         if _txt_da:
-            for minibatch, _contents in _txt_da.map_batch(
+            for minibatch, batch_data in _txt_da.map_batch(
                 self._preproc_texts,
-                batch_size=self._minibatch_size,
+                batch_size=minibatch_size,
                 pool=self._pool,
            ):
                 with self.monitor(
                     name='encode_texts_seconds',
                     documentation='texts encode time in seconds',
                 ):
-                    minibatch.embeddings = self._model.encode_text(minibatch.tensors)
-
-                    # recover original content
-                    try:
-                        _ = iter(_contents)
-                        for _d, _ct in zip(minibatch, _contents):
-                            _d.content = _ct
-                    except TypeError:
-                        pass
-
-        # drop tensors
-        docs.tensors = None
+                    minibatch.embeddings = self._model.encode_text(batch_data)
 
         return docs
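
The switch from `_contents` to `batch_data` reflects the new contract of the preprocessing callables: each now returns both the minibatch (with original contents intact) and a model-ready array, so `encode` no longer restores contents or drops tensors itself. A self-contained sketch of that shape (the toy "tokenizer" below is made up for illustration):

    from typing import Tuple

    import numpy as np
    from docarray import Document, DocumentArray


    def _toy_preproc_texts(docs: DocumentArray) -> Tuple[DocumentArray, np.ndarray]:
        # stand-in for real tokenization: fixed-length byte ids per text
        batch_data = np.stack(
            [
                np.frombuffer(d.text.encode()[:77].ljust(77), dtype=np.uint8)
                for d in docs
            ]
        )
        return docs, batch_data  # docs keep their original content


    da = DocumentArray([Document(text='hello'), Document(text='world')])
    for minibatch, batch_data in da.map_batch(_toy_preproc_texts, batch_size=2):
        print(len(minibatch), batch_data.shape)  # 2 (2, 77)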
43 changes: 15 additions & 28 deletions server/clip_server/executors/clip_tensorrt.py
@@ -19,7 +19,8 @@ def __init__(
         name: str = 'ViT-B/32',
         device: str = 'cuda',
         num_worker_preprocess: int = 4,
-        minibatch_size: int = 64,
+        minibatch_size: int = 32,
+        traversal_paths: str = '@r',
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -28,6 +29,8 @@ def __init__(
         self._pool = ThreadPool(processes=num_worker_preprocess)
 
         self._minibatch_size = minibatch_size
+        self._traversal_paths = traversal_paths
+
         self._device = device
 
         import torch
@@ -71,67 +74,51 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
         set_rank(docs)
 
     @requests
-    async def encode(self, docs: 'DocumentArray', **kwargs):
+    async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
+        traversal_paths = parameters.get('traversal_paths', self._traversal_paths)
+        minibatch_size = parameters.get('minibatch_size', self._minibatch_size)
+
         _img_da = DocumentArray()
         _txt_da = DocumentArray()
-        for d in docs:
+        for d in docs[traversal_paths]:
             split_img_txt_da(d, _img_da, _txt_da)
 
         # for image
         if _img_da:
-            for minibatch, _contents in _img_da.map_batch(
+            for minibatch, batch_data in _img_da.map_batch(
                 self._preproc_images,
-                batch_size=self._minibatch_size,
+                batch_size=minibatch_size,
                 pool=self._pool,
             ):
                 with self.monitor(
                     name='encode_images_seconds',
                     documentation='images encode time in seconds',
                 ):
                     minibatch.embeddings = (
-                        self._model.encode_image(minibatch.tensors)
+                        self._model.encode_image(batch_data)
                         .detach()
                         .cpu()
                         .numpy()
                         .astype(np.float32)
                     )
 
-                    # recover original content
-                    try:
-                        _ = iter(_contents)
-                        for _d, _ct in zip(minibatch, _contents):
-                            _d.content = _ct
-                    except TypeError:
-                        pass
-
         # for text
         if _txt_da:
-            for minibatch, _contents in _txt_da.map_batch(
+            for minibatch, batch_data in _txt_da.map_batch(
                 self._preproc_texts,
-                batch_size=self._minibatch_size,
+                batch_size=minibatch_size,
                 pool=self._pool,
             ):
                 with self.monitor(
                     name='encode_texts_seconds',
                     documentation='texts encode time in seconds',
                 ):
                     minibatch.embeddings = (
-                        self._model.encode_text(minibatch.tensors)
+                        self._model.encode_text(batch_data)
                         .detach()
                         .cpu()
                         .numpy()
                         .astype(np.float32)
                     )
 
-                    # recover original content
-                    try:
-                        _ = iter(_contents)
-                        for _d, _ct in zip(minibatch, _contents):
-                            _d.content = _ct
-                    except TypeError:
-                        pass
-
-        # drop tensors
-        docs.tensors = None
-
         return docs
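
For readers unfamiliar with the `'@r'` default used by all three executors: traversal paths are DocArray selectors, so one request can target root documents, their chunks, or both. A small illustration:

    from docarray import Document, DocumentArray

    docs = DocumentArray(
        [
            Document(
                text='root doc',
                chunks=[Document(text='chunk one'), Document(text='chunk two')],
            )
        ]
    )

    print(len(docs['@r']))    # 1 -> root documents only (the executors' default)
    print(len(docs['@c']))    # 2 -> chunks, one level down
    print(len(docs['@r,c']))  # 3 -> roots and chunks together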
