diff --git a/client/clip_client/client.py b/client/clip_client/client.py index 4b91583e7..69056818b 100644 --- a/client/clip_client/client.py +++ b/client/clip_client/client.py @@ -176,17 +176,15 @@ def _iter_doc( _mime = mimetypes.guess_type(c)[0] if _mime and _mime.startswith('image'): d = Document( - tags={'__created_by_CAS__': True, '__loaded_by_CAS__': True}, uri=c, ).load_uri_to_blob() else: - d = Document(tags={'__created_by_CAS__': True}, text=c) + d = Document(text=c) elif isinstance(c, Document): if c.content_type in ('text', 'blob'): d = c elif not c.blob and c.uri: c.load_uri_to_blob() - c.tags['__loaded_by_CAS__'] = True d = c elif c.tensor is not None: d = c @@ -288,8 +286,12 @@ def encode(self, content, **kwargs): results = DocumentArray() with self._pbar: - parameters = kwargs.pop('parameters', None) + parameters = kwargs.pop('parameters', {}) + parameters['drop_image_content'] = parameters.get( + 'drop_image_content', True + ) model_name = parameters.pop('model_name', '') if parameters else '' + self._client.post( on=f'/encode/{model_name}'.rstrip('/'), **self._get_post_payload(content, results, kwargs), @@ -299,10 +301,6 @@ def encode(self, content, **kwargs): parameters=parameters, ) - for r in results: - if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False): - r.pop('blob') - unbox = hasattr(content, '__len__') and isinstance(content[0], str) return self._unboxed_result(results, unbox) @@ -345,7 +343,10 @@ async def aencode(self, content, **kwargs): results = DocumentArray() with self._pbar: - parameters = kwargs.pop('parameters', None) + parameters = kwargs.pop('parameters', {}) + parameters['drop_image_content'] = parameters.get( + 'drop_image_content', True + ) model_name = parameters.get('model_name', '') if parameters else '' async for da in self._async_client.post( @@ -367,10 +368,6 @@ async def aencode(self, content, **kwargs): ), ) - for r in results: - if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False): - r.pop('blob') - unbox = hasattr(content, '__len__') and isinstance(content[0], str) return self._unboxed_result(results, unbox) @@ -423,7 +420,6 @@ def _prepare_single_doc(d: 'Document'): return d elif not d.blob and d.uri: d.load_uri_to_blob() - d.tags['__loaded_by_CAS__'] = True return d elif d.tensor is not None: return d @@ -439,18 +435,6 @@ def _prepare_rank_doc(d: 'Document', _source: str = 'matches'): setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)]) return d - @staticmethod - def _reset_rank_doc(d: 'Document', _source: str = 'matches'): - _get = lambda d: getattr(d, _source) - - if d.tags.pop('__loaded_by_CAS__', False): - d.pop('blob') - - for c in _get(d): - if c.tags.pop('__loaded_by_CAS__', False): - c.pop('blob') - return d - def rank( self, docs: Union['DocumentArray', Iterable['Document']], **kwargs ) -> 'DocumentArray': @@ -474,8 +458,12 @@ def rank( results = DocumentArray() with self._pbar: - parameters = kwargs.pop('parameters', None) + parameters = kwargs.pop('parameters', {}) + parameters['drop_image_content'] = parameters.get( + 'drop_image_content', True + ) model_name = parameters.get('model_name', '') if parameters else '' + self._client.post( on=f'/rank/{model_name}'.rstrip('/'), **self._get_rank_payload(docs, results, kwargs), @@ -485,9 +473,6 @@ def rank( parameters=parameters, ) - for r in results: - self._reset_rank_doc(r, _source=kwargs.get('source', 'matches')) - return results async def arank( @@ -507,8 +492,12 @@ async def arank( results = DocumentArray() with self._pbar: - parameters = kwargs.pop('parameters', None) + parameters = kwargs.pop('parameters', {}) + parameters['drop_image_content'] = parameters.get( + 'drop_image_content', True + ) model_name = parameters.get('model_name', '') if parameters else '' + async for da in self._async_client.post( on=f'/rank/{model_name}'.rstrip('/'), **self._get_rank_payload(docs, results, kwargs), @@ -528,9 +517,6 @@ async def arank( ), ) - for r in results: - self._reset_rank_doc(r, _source=kwargs.get('source', 'matches')) - return results @overload @@ -581,14 +567,21 @@ def index(self, content, **kwargs): raise TypeError( f'content must be an Iterable of [str, Document], try `.index(["{content}"])` instead' ) + if hasattr(content, '__len__') and len(content) == 0: + return DocumentArray() self._prepare_streaming( not kwargs.get('show_progress'), total=len(content) if hasattr(content, '__len__') else None, ) + results = DocumentArray() with self._pbar: - parameters = kwargs.pop('parameters', None) + parameters = kwargs.pop('parameters', {}) + parameters['drop_image_content'] = parameters.get( + 'drop_image_content', True + ) + self._client.post( on='/index', **self._get_post_payload(content, results, kwargs), @@ -598,10 +591,6 @@ def index(self, content, **kwargs): parameters=parameters, ) - for r in results: - if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False): - r.pop('blob') - return results @overload @@ -633,17 +622,25 @@ async def aindex(self, content, **kwargs): raise TypeError( f'content must be an Iterable of [str, Document], try `.aindex(["{content}"])` instead' ) + if hasattr(content, '__len__') and len(content) == 0: + return DocumentArray() self._prepare_streaming( not kwargs.get('show_progress'), total=len(content) if hasattr(content, '__len__') else None, ) + results = DocumentArray() with self._pbar: + parameters = kwargs.pop('parameters', {}) + parameters['drop_image_content'] = parameters.get( + 'drop_image_content', True + ) + async for da in self._async_client.post( on='/index', **self._get_post_payload(content, results, kwargs), - parameters=kwargs.pop('parameters', None), + parameters=parameters, ): results[da[:, 'id']].embeddings = da.embeddings @@ -659,10 +656,6 @@ async def aindex(self, content, **kwargs): ), ) - for r in results: - if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False): - r.pop('blob') - return results @overload @@ -716,15 +709,21 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray': raise TypeError( f'content must be an Iterable of [str, Document], try `.search(["{content}"])` instead' ) + if hasattr(content, '__len__') and len(content) == 0: + return DocumentArray() self._prepare_streaming( not kwargs.get('show_progress'), total=len(content) if hasattr(content, '__len__') else None, ) + results = DocumentArray() with self._pbar: parameters = kwargs.pop('parameters', {}) parameters['limit'] = limit + parameters['drop_image_content'] = parameters.get( + 'drop_image_content', True + ) self._client.post( on='/search', @@ -735,10 +734,6 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray': ), ) - for r in results: - if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False): - r.pop('blob') - return results @overload @@ -772,16 +767,21 @@ async def asearch(self, content, limit: int = 10, **kwargs): raise TypeError( f'content must be an Iterable of [str, Document], try `.asearch(["{content}"])` instead' ) + if hasattr(content, '__len__') and len(content) == 0: + return DocumentArray() self._prepare_streaming( not kwargs.get('show_progress'), total=len(content) if hasattr(content, '__len__') else None, ) - results = DocumentArray() + results = DocumentArray() with self._pbar: parameters = kwargs.pop('parameters', {}) parameters['limit'] = limit + parameters['drop_image_content'] = parameters.get( + 'drop_image_content', True + ) async for da in self._async_client.post( on='/search', @@ -802,8 +802,4 @@ async def asearch(self, content, limit: int = 10, **kwargs): ), ) - for r in results: - if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False): - r.pop('blob') - return results diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index c14da999d..9dff2ff21 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -2,6 +2,7 @@ import warnings from multiprocessing.pool import ThreadPool from typing import Optional, Dict +from functools import partial import onnxruntime as ort from clip_server.executors.helper import ( @@ -99,13 +100,16 @@ def __init__( self._model.start_sessions(sess_options=sess_options, providers=providers) - def _preproc_images(self, docs: 'DocumentArray'): + def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): with self.monitor( name='preprocess_images_seconds', documentation='images preprocess time in seconds', ): return preproc_image( - docs, preprocess_fn=self._image_transform, return_np=True + docs, + preprocess_fn=self._image_transform, + return_np=True, + drop_image_content=drop_image_content, ) def _preproc_texts(self, docs: 'DocumentArray'): @@ -117,7 +121,8 @@ def _preproc_texts(self, docs: 'DocumentArray'): @requests(on='/rank') async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): - await self.encode(docs['@r,m']) + _drop_image_content = parameters.get('drop_image_content', False) + await self.encode(docs['@r,m'], drop_image_content=_drop_image_content) set_rank(docs) @@ -129,6 +134,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): f'`traversal_paths` is deprecated. Use `access_paths` instead.' ) access_paths = parameters['traversal_paths'] + _drop_image_content = parameters.get('drop_image_content', False) _img_da = DocumentArray() _txt_da = DocumentArray() @@ -138,7 +144,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): # for image if _img_da: for minibatch, batch_data in _img_da.map_batch( - self._preproc_images, + partial(self._preproc_images, drop_image_content=_drop_image_content), batch_size=self._minibatch_size, pool=self._pool, ): diff --git a/server/clip_server/executors/clip_tensorrt.py b/server/clip_server/executors/clip_tensorrt.py index 0f13bd52e..24a9a6f7b 100644 --- a/server/clip_server/executors/clip_tensorrt.py +++ b/server/clip_server/executors/clip_tensorrt.py @@ -1,6 +1,7 @@ import warnings from multiprocessing.pool import ThreadPool from typing import Optional, Dict +from functools import partial import numpy as np from clip_server.executors.helper import ( @@ -67,7 +68,7 @@ def __init__( self._tokenizer = Tokenizer(name) self._image_transform = clip._transform_ndarray(self._model.image_size) - def _preproc_images(self, docs: 'DocumentArray'): + def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): with self.monitor( name='preprocess_images_seconds', documentation='images preprocess time in seconds', @@ -77,6 +78,7 @@ def _preproc_images(self, docs: 'DocumentArray'): preprocess_fn=self._image_transform, device=self._device, return_np=False, + drop_image_content=drop_image_content, ) def _preproc_texts(self, docs: 'DocumentArray'): @@ -90,7 +92,8 @@ def _preproc_texts(self, docs: 'DocumentArray'): @requests(on='/rank') async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): - await self.encode(docs['@r,m']) + _drop_image_content = parameters.get('drop_image_content', False) + await self.encode(docs['@r,m'], drop_image_content=_drop_image_content) set_rank(docs) @@ -102,6 +105,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): f'`traversal_paths` is deprecated. Use `access_paths` instead.' ) access_paths = parameters['traversal_paths'] + _drop_image_content = parameters.get('drop_image_content', False) _img_da = DocumentArray() _txt_da = DocumentArray() @@ -111,7 +115,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): # for image if _img_da: for minibatch, batch_data in _img_da.map_batch( - self._preproc_images, + partial(self._preproc_images, drop_image_content=_drop_image_content), batch_size=self._minibatch_size, pool=self._pool, ): diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index 64edc8236..d953cebc7 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -2,6 +2,7 @@ import warnings from multiprocessing.pool import ThreadPool from typing import Optional, Dict +from functools import partial import numpy as np import torch @@ -77,7 +78,7 @@ def __init__( self._tokenizer = Tokenizer(name) self._image_transform = clip._transform_ndarray(self._model.image_size) - def _preproc_images(self, docs: 'DocumentArray'): + def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): with self.monitor( name='preprocess_images_seconds', documentation='images preprocess time in seconds', @@ -87,6 +88,7 @@ def _preproc_images(self, docs: 'DocumentArray'): preprocess_fn=self._image_transform, device=self._device, return_np=False, + drop_image_content=drop_image_content, ) def _preproc_texts(self, docs: 'DocumentArray'): @@ -100,7 +102,8 @@ def _preproc_texts(self, docs: 'DocumentArray'): @requests(on='/rank') async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): - await self.encode(docs['@r,m']) + _drop_image_content = parameters.get('drop_image_content', False) + await self.encode(docs['@r,m'], drop_image_content=_drop_image_content) set_rank(docs) @@ -112,6 +115,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): f'`traversal_paths` is deprecated. Use `access_paths` instead.' ) access_paths = parameters['traversal_paths'] + _drop_image_content = parameters.get('drop_image_content', False) _img_da = DocumentArray() _txt_da = DocumentArray() @@ -122,7 +126,9 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): # for image if _img_da: for minibatch, batch_data in _img_da.map_batch( - self._preproc_images, + partial( + self._preproc_images, drop_image_content=_drop_image_content + ), batch_size=self._minibatch_size, pool=self._pool, ): diff --git a/server/clip_server/executors/helper.py b/server/clip_server/executors/helper.py index bfe852d7c..c7ac6a555 100644 --- a/server/clip_server/executors/helper.py +++ b/server/clip_server/executors/helper.py @@ -21,6 +21,7 @@ def preproc_image( preprocess_fn: Callable, device: str = 'cpu', return_np: bool = False, + drop_image_content: bool = False, ) -> Tuple['DocumentArray', Dict]: tensors_batch = [] @@ -37,10 +38,9 @@ def preproc_image( tensors_batch.append(preprocess_fn(d.tensor).detach()) # recover doc content - if d.tags.pop('__loaded_by_CAS__', False): - d.pop('tensor') - else: - d.content = content + d.content = content + if drop_image_content: + d.pop('blob', 'tensor') tensors_batch = torch.stack(tensors_batch).type(torch.float32) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 0501a98f8..cf3618be4 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -43,12 +43,10 @@ async def test_async_docarray_preserve_original_inputs(make_flow, inputs): t2 = asyncio.create_task(c.aencode(inputs if not callable(inputs) else inputs())) await asyncio.gather(t1, t2) assert isinstance(t2.result(), DocumentArray) + assert inputs[0] is t2.result()[0] assert t2.result().embeddings.shape assert t2.result().contents == inputs.contents - assert '__created_by_CAS__' not in t2.result()[0].tags - assert '__loaded_by_CAS__' not in t2.result()[0].tags assert not t2.result()[0].tensor - assert not t2.result()[0].blob assert inputs[0] is t2.result()[0] diff --git a/tests/test_client.py b/tests/test_client.py index c15563924..c3bf4511b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -108,6 +108,22 @@ async def test_client_empty_input(make_torch_flow, inputs): assert isinstance(r, list) assert len(r) == 0 + r = c.index(inputs if not callable(inputs) else inputs()) + assert isinstance(r, DocumentArray) + assert len(r) == 0 + + r = await c.aindex(inputs if not callable(inputs) else inputs()) + assert isinstance(r, DocumentArray) + assert len(r) == 0 + + r = c.search(inputs if not callable(inputs) else inputs()) + assert isinstance(r, DocumentArray) + assert len(r) == 0 + + r = await c.asearch(inputs if not callable(inputs) else inputs()) + assert isinstance(r, DocumentArray) + assert len(r) == 0 + @pytest.mark.asyncio async def test_wrong_input_type(make_torch_flow): diff --git a/tests/test_helper.py b/tests/test_helper.py index 49cbfe88f..acce92ad2 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -85,10 +85,6 @@ def test_split_img_txt_da(inputs): [ DocumentArray( [ - Document( - uri='https://docarray.jina.ai/_static/favicon.png', - tags={'__loaded_by_CAS__': True}, - ).load_uri_to_blob(), Document( uri='https://docarray.jina.ai/_static/favicon.png', ).load_uri_to_blob(), @@ -100,10 +96,8 @@ def test_preproc_image(inputs): from clip_server.model import clip preprocess_fn = clip._transform_ndarray(224) - da, pixel_values = preproc_image(inputs, preprocess_fn) - assert len(da) == 2 + da, pixel_values = preproc_image(inputs, preprocess_fn, drop_image_content=True) + assert len(da) == 1 assert not da[0].blob - assert da[1].blob assert not da[0].tensor - assert not da[1].tensor assert pixel_values.get('pixel_values') is not None diff --git a/tests/test_ranker.py b/tests/test_ranker.py index 60662af1b..70f8165b3 100644 --- a/tests/test_ranker.py +++ b/tests/test_ranker.py @@ -30,14 +30,10 @@ async def test_torch_executor_rank_img2texts(encoder_class): for d in da: for c in d.matches: assert c.scores['clip_score'].value is not None - assert '__loaded_by_CAS__' not in c.tags assert not c.tensor - assert not c.blob org_score = d.matches[:, 'scores__clip_score__value'] assert org_score == list(sorted(org_score, reverse=True)) - assert '__loaded_by_CAS__' not in d.tags assert not d.tensor - assert not d.blob @pytest.mark.asyncio @@ -59,13 +55,10 @@ async def test_torch_executor_rank_text2imgs(encoder_class): for c in d.matches: assert c.scores['clip_score'].value is not None assert c.scores['clip_score_cosine'].value is not None - assert '__loaded_by_CAS__' not in c.tags assert not c.tensor - assert not c.blob np.testing.assert_almost_equal( sum(c.scores['clip_score'].value for c in d.matches), 1 ) - assert '__loaded_by_CAS__' not in d.tags assert not d.tensor assert not d.blob @@ -135,8 +128,6 @@ async def test_torch_executor_rank_text2imgs(encoder_class): def test_docarray_inputs(make_flow, inputs): c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') r = c.rank(inputs if not callable(inputs) else inputs()) - assert '__loaded_by_CAS__' not in r[0].tags - assert not r[0].blob assert not r[0].tensor assert isinstance(r, DocumentArray) rv1 = r['@m', 'scores__clip_score__value'] @@ -200,8 +191,6 @@ def test_docarray_inputs(make_flow, inputs): async def test_async_arank(make_flow, inputs): c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') r = await c.arank(inputs if not callable(inputs) else inputs()) - assert '__loaded_by_CAS__' not in r[0].tags - assert not r[0].blob assert not r[0].tensor assert isinstance(r, DocumentArray) rv = r['@m', 'scores__clip_score__value'] diff --git a/tests/test_simple.py b/tests/test_simple.py index 5648aa8be..dcc8ef2c6 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -76,10 +76,7 @@ def test_docarray_inputs(make_flow, inputs): r = c.encode(inputs if not callable(inputs) else inputs()) assert isinstance(r, DocumentArray) assert r.embeddings.shape - assert '__created_by_CAS__' not in r[0].tags - assert '__loaded_by_CAS__' not in r[0].tags assert not r[0].tensor - assert not r[0].blob if hasattr(inputs, '__len__'): assert inputs[0] is r[0] @@ -107,10 +104,7 @@ def test_docarray_preserve_original_inputs(make_flow, inputs): assert isinstance(r, DocumentArray) assert r.embeddings.shape assert r.contents == inputs.contents - assert '__created_by_CAS__' not in r[0].tags - assert '__loaded_by_CAS__' not in r[0].tags assert not r[0].tensor - assert not r[0].blob assert inputs[0] is r[0] @@ -141,8 +135,6 @@ def test_docarray_traversal(make_flow, inputs): r1 = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'}) assert isinstance(r1, DocumentArray) assert r1[0].chunks.embeddings.shape[0] == len(inputs) - assert '__created_by_CAS__' not in r1[0].tags - assert '__loaded_by_CAS__' not in r1[0].tags assert not r1[0].tensor assert not r1[0].blob assert not r1[0].chunks[0].tensor @@ -151,8 +143,6 @@ def test_docarray_traversal(make_flow, inputs): r2 = c.post(on='/', inputs=da, parameters={'access_paths': '@c'}) assert isinstance(r2, DocumentArray) assert r2[0].chunks.embeddings.shape[0] == len(inputs) - assert '__created_by_CAS__' not in r2[0].tags - assert '__loaded_by_CAS__' not in r2[0].tags assert not r2[0].tensor assert not r2[0].blob assert not r2[0].chunks[0].tensor