diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index 0d73ff48f..bf9ae6da4 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -23,14 +23,20 @@ def __init__( device: Optional[str] = None, num_worker_preprocess: int = 4, minibatch_size: int = 32, - traversal_paths: str = '@r', + access_paths: str = '@r', + traversal_paths: Optional[str] = '@r', model_path: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) self._minibatch_size = minibatch_size - self._traversal_paths = traversal_paths + self._access_paths = access_paths + if traversal_paths is not None: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + self._access_paths = traversal_paths self._pool = ThreadPool(processes=num_worker_preprocess) @@ -105,11 +111,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): - traversal_paths = parameters.get('traversal_paths', self._traversal_paths) + access_paths = parameters.get('access_paths', self._access_paths) + if 'traversal_paths' in parameters: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + access_paths = parameters['traversal_paths'] _img_da = DocumentArray() _txt_da = DocumentArray() - for d in docs[traversal_paths]: + for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) # for image diff --git a/server/clip_server/executors/clip_tensorrt.py b/server/clip_server/executors/clip_tensorrt.py index dd2b13542..440294c84 100644 --- a/server/clip_server/executors/clip_tensorrt.py +++ b/server/clip_server/executors/clip_tensorrt.py @@ -1,5 +1,6 @@ +import warnings from multiprocessing.pool import ThreadPool -from typing import Dict +from typing import Optional, Dict import numpy as np from clip_server.executors.helper import ( @@ -21,7 +22,8 @@ def __init__( device: str = 'cuda', num_worker_preprocess: int = 4, minibatch_size: int = 32, - traversal_paths: str = '@r', + access_paths: str = '@r', + traversal_paths: Optional[str] = '@r', **kwargs, ): super().__init__(**kwargs) @@ -29,7 +31,12 @@ def __init__( self._pool = ThreadPool(processes=num_worker_preprocess) self._minibatch_size = minibatch_size - self._traversal_paths = traversal_paths + self._access_paths = access_paths + if traversal_paths is not None: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + self._access_paths = traversal_paths self._device = device @@ -80,11 +87,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): - traversal_paths = parameters.get('traversal_paths', self._traversal_paths) + access_paths = parameters.get('access_paths', self._access_paths) + if 'traversal_paths' in parameters: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + access_paths = parameters['traversal_paths'] _img_da = DocumentArray() _txt_da = DocumentArray() - for d in docs[traversal_paths]: + for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) # for image diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index f7e861a84..224b22426 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -25,13 +25,19 @@ def __init__( jit: bool = False, num_worker_preprocess: int = 4, minibatch_size: int = 32, - traversal_paths: str = '@r', + access_paths: str = '@r', + traversal_paths: Optional[str] = '@r', **kwargs, ): super().__init__(**kwargs) self._minibatch_size = minibatch_size - self._traversal_paths = traversal_paths + self._access_paths = access_paths + if traversal_paths is not None: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + self._access_paths = traversal_paths if not device: self._device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -90,11 +96,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): - traversal_paths = parameters.get('traversal_paths', self._traversal_paths) + access_paths = parameters.get('access_paths', self._access_paths) + if 'traversal_paths' in parameters: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + access_paths = parameters['traversal_paths'] _img_da = DocumentArray() _txt_da = DocumentArray() - for d in docs[traversal_paths]: + for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) with torch.inference_mode(): diff --git a/tests/test_simple.py b/tests/test_simple.py index 027df96e3..0f19bacc9 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -130,6 +130,9 @@ def test_docarray_traversal(make_flow, inputs, port_generator): from jina import Client as _Client c = _Client(host=f'grpc://0.0.0.0', port=make_flow.port) - r = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'}) - assert r[0].chunks.embeddings.shape[0] == len(inputs) - assert '__created_by_CAS__' not in r[0].tags + r1 = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'}) + r2 = c.post(on='/', inputs=da, parameters={'access_paths': '@c'}) + assert r1[0].chunks.embeddings.shape[0] == len(inputs) + assert '__created_by_CAS__' not in r1[0].tags + assert r2[0].chunks.embeddings.shape[0] == len(inputs) + assert '__created_by_CAS__' not in r2[0].tags