Commit a7311fb
fix(server): recover original contents of the input da (#726)
hanxiao authored May 23, 2022
1 parent 42ef75b commit a7311fb
Showing 5 changed files with 76 additions and 21 deletions.
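
The change addresses a side effect of preprocessing: preproc_image and preproc_text overwrite each Document's payload (text or blob) with a model-ready tensor, and encode() later drops those tensors, so the DocumentArray returned to the client no longer carried the contents the user sent in. The fix snapshots da.contents before preprocessing and writes it back onto each Document once the embeddings are computed. The sketch below is not part of the commit; fake_tokenize and the random embeddings are stand-ins for the real tokenizer and model, but it shows the same save-and-restore pattern the diff applies.

import numpy as np
from docarray import Document, DocumentArray


def fake_tokenize(texts):
    # stand-in for clip.tokenize(...) from clip_server.model: one int64 row per text
    return np.zeros((len(texts), 77), dtype=np.int64)


def preproc_text_sketch(da: 'DocumentArray'):
    contents = da.contents                # snapshot what the caller sent in
    da.tensors = fake_tokenize(contents)  # this overwrites the documents' payload
    return da, contents


da = DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')])
minibatch, originals = preproc_text_sketch(da)
minibatch.embeddings = np.random.rand(len(minibatch), 512).astype(np.float32)  # model stand-in
for _d, _ct in zip(minibatch, originals):
    _d.content = _ct                      # put the original text/blob back
minibatch.tensors = None                  # drop the now-useless tensors, as encode() does
print(minibatch.contents)                 # ['hello, world', 'goodbye, world']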
16 changes: 12 additions & 4 deletions server/clip_server/executors/clip_onnx.py
@@ -5,6 +5,8 @@
 from typing import Optional, Dict
 
 import onnxruntime as ort
+from jina import Executor, requests, DocumentArray
+
 from clip_server.executors.helper import (
     split_img_txt_da,
     preproc_image,
@@ -13,7 +15,6 @@
 )
 from clip_server.model import clip
 from clip_server.model.clip_onnx import CLIPOnnxModel
-from jina import Executor, requests, DocumentArray
 
 
 class CLIPEncoder(Executor):
@@ -93,24 +94,31 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
 
         # for image
         if _img_da:
-            for minibatch in _img_da.map_batch(
+            for minibatch, _contents in _img_da.map_batch(
                 partial(
                     preproc_image, preprocess_fn=self._preprocess_tensor, return_np=True
                 ),
                 batch_size=self._minibatch_size,
                 pool=self._pool,
             ):
                 minibatch.embeddings = self._model.encode_image(minibatch.tensors)
+                # recover original content
+                if _contents:
+                    for _d, _ct in zip(minibatch, _contents):
+                        _d.content = _ct
 
         # for text
         if _txt_da:
-            for minibatch, _texts in _txt_da.map_batch(
+            for minibatch, _contents in _txt_da.map_batch(
                 partial(preproc_text, return_np=True),
                 batch_size=self._minibatch_size,
                 pool=self._pool,
             ):
                 minibatch.embeddings = self._model.encode_text(minibatch.tensors)
-                minibatch.texts = _texts
+                # recover original content
+                if _contents:
+                    for _d, _ct in zip(minibatch, _contents):
+                        _d.content = _ct
 
         # drop tensors
         docs.tensors = None
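The same pattern repeats in the TensorRT and PyTorch executors below. The reason each loop now unpacks two values is that DocumentArray.map_batch yields whatever the mapped function returns for each batch, so once preproc_image/preproc_text return a (batch, contents) tuple, the loop receives both the preprocessed minibatch and the saved originals. A rough, self-contained sketch (the stamp preprocessing function is hypothetical, not from this repo):

from functools import partial

from docarray import Document, DocumentArray


def stamp(da: 'DocumentArray', tag: str):
    # hypothetical preprocessing: snapshot contents, then mutate the batch in place
    contents = da.contents
    for d in da:
        d.tags['preprocessed_by'] = tag
    return da, contents


da = DocumentArray([Document(text=f'doc {i}') for i in range(10)])
for minibatch, originals in da.map_batch(partial(stamp, tag='onnx'), batch_size=4):
    # each iteration gets one batch of up to 4 documents plus its content snapshot
    print(len(minibatch), originals[0])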
16 changes: 12 additions & 4 deletions server/clip_server/executors/clip_tensorrt.py
@@ -3,6 +3,8 @@
 from typing import Dict
 
 import numpy as np
+from jina import Executor, requests, DocumentArray
+
 from clip_server.executors.helper import (
     split_img_txt_da,
     preproc_image,
@@ -11,7 +13,6 @@
 )
 from clip_server.model import clip
 from clip_server.model.clip_trt import CLIPTensorRTModel
-from jina import Executor, requests, DocumentArray
 
 
 class CLIPEncoder(Executor):
@@ -61,7 +62,7 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
 
         # for image
         if _img_da:
-            for minibatch in _img_da.map_batch(
+            for minibatch, _contents in _img_da.map_batch(
                 partial(
                     preproc_image,
                     preprocess_fn=self._preprocess_tensor,
@@ -78,10 +79,14 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
                     .numpy()
                     .astype(np.float32)
                 )
+                # recover original content
+                if _contents:
+                    for _d, _ct in zip(minibatch, _contents):
+                        _d.content = _ct
 
         # for text
         if _txt_da:
-            for minibatch, _texts in _txt_da.map_batch(
+            for minibatch, _contents in _txt_da.map_batch(
                 partial(preproc_text, device=self._device, return_np=False),
                 batch_size=self._minibatch_size,
                 pool=self._pool,
@@ -93,7 +98,10 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
                     .numpy()
                     .astype(np.float32)
                 )
-                minibatch.texts = _texts
+                # recover original content
+                if _contents:
+                    for _d, _ct in zip(minibatch, _contents):
+                        _d.content = _ct
 
         # drop tensors
         docs.tensors = None
18 changes: 14 additions & 4 deletions server/clip_server/executors/clip_torch.py
@@ -6,14 +6,15 @@
 
 import numpy as np
 import torch
+from jina import Executor, requests, DocumentArray
+
 from clip_server.executors.helper import (
     split_img_txt_da,
     preproc_image,
     preproc_text,
     set_rank,
 )
 from clip_server.model import clip
-from jina import Executor, requests, DocumentArray
 
 
 class CLIPEncoder(Executor):
@@ -76,7 +77,7 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
         with torch.inference_mode():
             # for image
             if _img_da:
-                for minibatch in _img_da.map_batch(
+                for minibatch, _contents in _img_da.map_batch(
                     partial(
                         preproc_image,
                         preprocess_fn=self._preprocess_tensor,
@@ -93,9 +94,14 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
                         .astype(np.float32)
                     )
 
+                    # recover original content
+                    if _contents:
+                        for _d, _ct in zip(minibatch, _contents):
+                            _d.content = _ct
+
             # for text
             if _txt_da:
-                for minibatch, _texts in _txt_da.map_batch(
+                for minibatch, _contents in _txt_da.map_batch(
                     partial(preproc_text, device=self._device, return_np=False),
                     batch_size=self._minibatch_size,
                     pool=self._pool,
@@ -106,7 +112,11 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
                         .numpy()
                         .astype(np.float32)
                     )
-                    minibatch.texts = _texts
+
+                    # recover original content
+                    if _contents:
+                        for _d, _ct in zip(minibatch, _contents):
+                            _d.content = _ct
 
         # drop tensors
         docs.tensors = None
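In both the TensorRT and PyTorch executors the embedding assignment runs the model output through .detach().cpu().numpy().astype(np.float32) before it is stored on the minibatch, and the content-restore block sits right after that conversion. A minimal sketch of that conversion chain, using a tiny stand-in model rather than CLIP:

import numpy as np
import torch
from docarray import Document, DocumentArray

model = torch.nn.Linear(4, 3)  # stand-in for the CLIP visual/text tower

da = DocumentArray([Document(tensor=np.random.rand(4).astype(np.float32)) for _ in range(2)])
with torch.inference_mode():
    out = model(torch.from_numpy(da.tensors))  # (2, 3) torch.Tensor
    # detach from the graph, move to host memory, and store as float32 numpy,
    # which is the form DocumentArray ships back to the client
    da.embeddings = out.detach().cpu().numpy().astype(np.float32)
print(da.embeddings.shape, da.embeddings.dtype)  # (2, 3) float32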
22 changes: 13 additions & 9 deletions server/clip_server/executors/helper.py
@@ -1,4 +1,4 @@
-from typing import Tuple, List, Callable
+from typing import Tuple, List, Callable, Any
 
 import numpy as np
 from docarray import Document, DocumentArray
@@ -20,7 +20,9 @@ def preproc_image(
     preprocess_fn: Callable,
     device: str = 'cpu',
     return_np: bool = False,
-) -> 'DocumentArray':
+) -> Tuple['DocumentArray', List[Any]]:
+    contents = da.contents
+
     for d in da:
         if d.blob:
             d.convert_blob_to_image_tensor()
@@ -34,22 +36,24 @@ def preproc_image(
         da.tensors = da.tensors.cpu().numpy().astype(np.float32)
     else:
         da.tensors = da.tensors.to(device)
-    return da
+
+    return da, contents
 
 
 def preproc_text(
     da: 'DocumentArray', device: str = 'cpu', return_np: bool = False
-) -> Tuple['DocumentArray', List[str]]:
-    texts = da.texts
-    da.tensors = clip.tokenize(texts).detach()
+) -> Tuple['DocumentArray', List[Any]]:
+    contents = da.contents
+
+    da.tensors = clip.tokenize(contents).detach()
 
     if return_np:
         da.tensors = da.tensors.cpu().numpy().astype(np.int64)
     else:
         da.tensors = da.tensors.to(device)
 
     da[:, 'mime_type'] = 'text'
-    return da, texts
+    return da, contents
 
 
 def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'):
@@ -90,10 +94,10 @@ def set_rank(docs, _logit_scale=np.exp(4.60517)):
 
         start_idx = end_idx
 
-        _candidates.embeddings = None  # remove embedding to save bandwidth
-
         final = sorted(
             _candidates, key=lambda _m: _m.scores['clip_score'].value, reverse=True
         )
 
+        final.embeddings = None  # remove embedding to save bandwidth
+
         q.matches = final
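
With the helper changes above, both preprocessing functions now return a (DocumentArray, contents) tuple, preproc_text tokenizes from the same da.contents snapshot it hands back, and set_rank strips embeddings from the sorted matches rather than from _candidates. A hedged usage sketch of the new contract, assuming clip_server is installed and its tokenizer assets are available:

from docarray import Document, DocumentArray

from clip_server.executors.helper import preproc_text

da = DocumentArray([Document(text='a photo of a cat')])
da, contents = preproc_text(da, return_np=True)

print(da.tensors.shape)  # token ids, e.g. (1, 77) int64 for CLIP's context length
print(contents)          # ['a photo of a cat'], the snapshot the executor restores later
print(da[0].mime_type)   # 'text'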
25 changes: 25 additions & 0 deletions tests/test_simple.py
@@ -77,3 +77,28 @@ def test_docarray_inputs(make_flow, inputs, port_generator):
     r = c.encode(inputs if not callable(inputs) else inputs())
     assert isinstance(r, DocumentArray)
     assert r.embeddings.shape
+
+
+@pytest.mark.parametrize(
+    'inputs',
+    [
+        DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]),
+        DocumentArray(
+            [
+                Document(
+                    uri='https://docarray.jina.ai/_static/favicon.png',
+                    text='hello, world',
+                ),
+            ]
+        ),
+        DocumentArray.from_files(
+            f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
+        ),
+    ],
+)
+def test_docarray_preserve_original_inputs(make_flow, inputs, port_generator):
+    c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
+    r = c.encode(inputs if not callable(inputs) else inputs())
+    assert isinstance(r, DocumentArray)
+    assert r.embeddings.shape
+    assert r.contents == inputs.contents
