diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index fb0ee26ab..67d2dd29f 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -5,6 +5,8 @@ from typing import Optional, Dict import onnxruntime as ort +from jina import Executor, requests, DocumentArray + from clip_server.executors.helper import ( split_img_txt_da, preproc_image, @@ -13,7 +15,6 @@ ) from clip_server.model import clip from clip_server.model.clip_onnx import CLIPOnnxModel -from jina import Executor, requests, DocumentArray class CLIPEncoder(Executor): @@ -93,7 +94,7 @@ async def encode(self, docs: 'DocumentArray', **kwargs): # for image if _img_da: - for minibatch in _img_da.map_batch( + for minibatch, _contents in _img_da.map_batch( partial( preproc_image, preprocess_fn=self._preprocess_tensor, return_np=True ), @@ -101,16 +102,23 @@ async def encode(self, docs: 'DocumentArray', **kwargs): pool=self._pool, ): minibatch.embeddings = self._model.encode_image(minibatch.tensors) + # recover original content + if _contents: + for _d, _ct in zip(minibatch, _contents): + _d.content = _ct # for text if _txt_da: - for minibatch, _texts in _txt_da.map_batch( + for minibatch, _contents in _txt_da.map_batch( partial(preproc_text, return_np=True), batch_size=self._minibatch_size, pool=self._pool, ): minibatch.embeddings = self._model.encode_text(minibatch.tensors) - minibatch.texts = _texts + # recover original content + if _contents: + for _d, _ct in zip(minibatch, _contents): + _d.content = _ct # drop tensors docs.tensors = None diff --git a/server/clip_server/executors/clip_tensorrt.py b/server/clip_server/executors/clip_tensorrt.py index 721f3a457..f60c87dd6 100644 --- a/server/clip_server/executors/clip_tensorrt.py +++ b/server/clip_server/executors/clip_tensorrt.py @@ -3,6 +3,8 @@ from typing import Dict import numpy as np +from jina import Executor, requests, DocumentArray + from clip_server.executors.helper import ( split_img_txt_da, preproc_image, @@ -11,7 +13,6 @@ ) from clip_server.model import clip from clip_server.model.clip_trt import CLIPTensorRTModel -from jina import Executor, requests, DocumentArray class CLIPEncoder(Executor): @@ -61,7 +62,7 @@ async def encode(self, docs: 'DocumentArray', **kwargs): # for image if _img_da: - for minibatch in _img_da.map_batch( + for minibatch, _contents in _img_da.map_batch( partial( preproc_image, preprocess_fn=self._preprocess_tensor, @@ -78,10 +79,14 @@ async def encode(self, docs: 'DocumentArray', **kwargs): .numpy() .astype(np.float32) ) + # recover original content + if _contents: + for _d, _ct in zip(minibatch, _contents): + _d.content = _ct # for text if _txt_da: - for minibatch, _texts in _txt_da.map_batch( + for minibatch, _contents in _txt_da.map_batch( partial(preproc_text, device=self._device, return_np=False), batch_size=self._minibatch_size, pool=self._pool, @@ -93,7 +98,10 @@ async def encode(self, docs: 'DocumentArray', **kwargs): .numpy() .astype(np.float32) ) - minibatch.texts = _texts + # recover original content + if _contents: + for _d, _ct in zip(minibatch, _contents): + _d.content = _ct # drop tensors docs.tensors = None diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index 7a903b8be..2c1e05ad4 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -6,6 +6,8 @@ import numpy as np import torch +from jina import Executor, requests, DocumentArray + from clip_server.executors.helper import ( split_img_txt_da, preproc_image, @@ -13,7 +15,6 @@ set_rank, ) from clip_server.model import clip -from jina import Executor, requests, DocumentArray class CLIPEncoder(Executor): @@ -76,7 +77,7 @@ async def encode(self, docs: 'DocumentArray', **kwargs): with torch.inference_mode(): # for image if _img_da: - for minibatch in _img_da.map_batch( + for minibatch, _contents in _img_da.map_batch( partial( preproc_image, preprocess_fn=self._preprocess_tensor, @@ -93,9 +94,14 @@ async def encode(self, docs: 'DocumentArray', **kwargs): .astype(np.float32) ) + # recover original content + if _contents: + for _d, _ct in zip(minibatch, _contents): + _d.content = _ct + # for text if _txt_da: - for minibatch, _texts in _txt_da.map_batch( + for minibatch, _contents in _txt_da.map_batch( partial(preproc_text, device=self._device, return_np=False), batch_size=self._minibatch_size, pool=self._pool, @@ -106,7 +112,11 @@ async def encode(self, docs: 'DocumentArray', **kwargs): .numpy() .astype(np.float32) ) - minibatch.texts = _texts + + # recover original content + if _contents: + for _d, _ct in zip(minibatch, _contents): + _d.content = _ct # drop tensors docs.tensors = None diff --git a/server/clip_server/executors/helper.py b/server/clip_server/executors/helper.py index 8c3d9fa09..023861462 100644 --- a/server/clip_server/executors/helper.py +++ b/server/clip_server/executors/helper.py @@ -1,4 +1,4 @@ -from typing import Tuple, List, Callable +from typing import Tuple, List, Callable, Any import numpy as np from docarray import Document, DocumentArray @@ -20,7 +20,9 @@ def preproc_image( preprocess_fn: Callable, device: str = 'cpu', return_np: bool = False, -) -> 'DocumentArray': +) -> Tuple['DocumentArray', List[Any]]: + contents = da.contents + for d in da: if d.blob: d.convert_blob_to_image_tensor() @@ -34,14 +36,16 @@ def preproc_image( da.tensors = da.tensors.cpu().numpy().astype(np.float32) else: da.tensors = da.tensors.to(device) - return da + + return da, contents def preproc_text( da: 'DocumentArray', device: str = 'cpu', return_np: bool = False -) -> Tuple['DocumentArray', List[str]]: - texts = da.texts - da.tensors = clip.tokenize(texts).detach() +) -> Tuple['DocumentArray', List[Any]]: + contents = da.contents + + da.tensors = clip.tokenize(contents).detach() if return_np: da.tensors = da.tensors.cpu().numpy().astype(np.int64) @@ -49,7 +53,7 @@ def preproc_text( da.tensors = da.tensors.to(device) da[:, 'mime_type'] = 'text' - return da, texts + return da, contents def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'): @@ -90,10 +94,10 @@ def set_rank(docs, _logit_scale=np.exp(4.60517)): start_idx = end_idx + _candidates.embeddings = None # remove embedding to save bandwidth + final = sorted( _candidates, key=lambda _m: _m.scores['clip_score'].value, reverse=True ) - final.embeddings = None # remove embedding to save bandwidth - q.matches = final diff --git a/tests/test_simple.py b/tests/test_simple.py index 00132a14d..10ccbef8f 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -77,3 +77,28 @@ def test_docarray_inputs(make_flow, inputs, port_generator): r = c.encode(inputs if not callable(inputs) else inputs()) assert isinstance(r, DocumentArray) assert r.embeddings.shape + + +@pytest.mark.parametrize( + 'inputs', + [ + DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]), + DocumentArray( + [ + Document( + uri='https://docarray.jina.ai/_static/favicon.png', + text='hello, world', + ), + ] + ), + DocumentArray.from_files( + f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + ), + ], +) +def test_docarray_preserve_original_inputs(make_flow, inputs, port_generator): + c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + r = c.encode(inputs if not callable(inputs) else inputs()) + assert isinstance(r, DocumentArray) + assert r.embeddings.shape + assert r.contents == inputs.contents