From 3402b1d1726120d8ed39ae561e441695f24ddeb3 Mon Sep 17 00:00:00 2001 From: Ziniu Yu Date: Wed, 3 Aug 2022 10:10:53 +0800 Subject: [PATCH] feat: replace traversal_paths with access_paths (#791) * feat: adapt access_paths deprecate traversal_paths * fix: put traversal_paths in kwargs * fix: remove unused param --- server/clip_server/executors/clip_onnx.py | 18 ++++++++++++---- server/clip_server/executors/clip_tensorrt.py | 21 ++++++++++++++----- server/clip_server/executors/clip_torch.py | 18 ++++++++++++---- tests/test_simple.py | 9 +++++--- 4 files changed, 50 insertions(+), 16 deletions(-) diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index 0d73ff48f..0a6a3e6d4 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -23,14 +23,19 @@ def __init__( device: Optional[str] = None, num_worker_preprocess: int = 4, minibatch_size: int = 32, - traversal_paths: str = '@r', + access_paths: str = '@r', model_path: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) self._minibatch_size = minibatch_size - self._traversal_paths = traversal_paths + self._access_paths = access_paths + if 'traversal_paths' in kwargs: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + self._access_paths = kwargs['traversal_paths'] self._pool = ThreadPool(processes=num_worker_preprocess) @@ -105,11 +110,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): - traversal_paths = parameters.get('traversal_paths', self._traversal_paths) + access_paths = parameters.get('access_paths', self._access_paths) + if 'traversal_paths' in parameters: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + access_paths = parameters['traversal_paths'] _img_da = DocumentArray() _txt_da = DocumentArray() - for d in docs[traversal_paths]: + for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) # for image diff --git a/server/clip_server/executors/clip_tensorrt.py b/server/clip_server/executors/clip_tensorrt.py index dd2b13542..60eaa50a6 100644 --- a/server/clip_server/executors/clip_tensorrt.py +++ b/server/clip_server/executors/clip_tensorrt.py @@ -1,5 +1,6 @@ +import warnings from multiprocessing.pool import ThreadPool -from typing import Dict +from typing import Optional, Dict import numpy as np from clip_server.executors.helper import ( @@ -21,7 +22,7 @@ def __init__( device: str = 'cuda', num_worker_preprocess: int = 4, minibatch_size: int = 32, - traversal_paths: str = '@r', + access_paths: str = '@r', **kwargs, ): super().__init__(**kwargs) @@ -29,7 +30,12 @@ def __init__( self._pool = ThreadPool(processes=num_worker_preprocess) self._minibatch_size = minibatch_size - self._traversal_paths = traversal_paths + self._access_paths = access_paths + if 'traversal_paths' in kwargs: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + self._access_paths = kwargs['traversal_paths'] self._device = device @@ -80,11 +86,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): - traversal_paths = parameters.get('traversal_paths', self._traversal_paths) + access_paths = parameters.get('access_paths', self._access_paths) + if 'traversal_paths' in parameters: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + access_paths = parameters['traversal_paths'] _img_da = DocumentArray() _txt_da = DocumentArray() - for d in docs[traversal_paths]: + for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) # for image diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index f7e861a84..5aab4f5f0 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -25,13 +25,18 @@ def __init__( jit: bool = False, num_worker_preprocess: int = 4, minibatch_size: int = 32, - traversal_paths: str = '@r', + access_paths: str = '@r', **kwargs, ): super().__init__(**kwargs) self._minibatch_size = minibatch_size - self._traversal_paths = traversal_paths + self._access_paths = access_paths + if 'traversal_paths' in kwargs: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + self._access_paths = kwargs['traversal_paths'] if not device: self._device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -90,11 +95,16 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): - traversal_paths = parameters.get('traversal_paths', self._traversal_paths) + access_paths = parameters.get('access_paths', self._access_paths) + if 'traversal_paths' in parameters: + warnings.warn( + f'`traversal_paths` is deprecated. Use `access_paths` instead.' + ) + access_paths = parameters['traversal_paths'] _img_da = DocumentArray() _txt_da = DocumentArray() - for d in docs[traversal_paths]: + for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) with torch.inference_mode(): diff --git a/tests/test_simple.py b/tests/test_simple.py index 027df96e3..0f19bacc9 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -130,6 +130,9 @@ def test_docarray_traversal(make_flow, inputs, port_generator): from jina import Client as _Client c = _Client(host=f'grpc://0.0.0.0', port=make_flow.port) - r = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'}) - assert r[0].chunks.embeddings.shape[0] == len(inputs) - assert '__created_by_CAS__' not in r[0].tags + r1 = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'}) + r2 = c.post(on='/', inputs=da, parameters={'access_paths': '@c'}) + assert r1[0].chunks.embeddings.shape[0] == len(inputs) + assert '__created_by_CAS__' not in r1[0].tags + assert r2[0].chunks.embeddings.shape[0] == len(inputs) + assert '__created_by_CAS__' not in r2[0].tags