From 1db43b485b0fe368eb3949ddc052b5dd8002c279 Mon Sep 17 00:00:00 2001 From: Ziniu Yu Date: Thu, 28 Jul 2022 14:26:58 +0800 Subject: [PATCH] fix: no allow client to change server batch size (#787) --- server/clip_server/executors/clip_onnx.py | 6 ++---- server/clip_server/executors/clip_tensorrt.py | 5 ++--- server/clip_server/executors/clip_torch.py | 5 ++--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index 51d946bc7..0d73ff48f 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -105,9 +105,7 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): - traversal_paths = parameters.get('traversal_paths', self._traversal_paths) - minibatch_size = parameters.get('minibatch_size', self._minibatch_size) _img_da = DocumentArray() _txt_da = DocumentArray() @@ -118,7 +116,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): if _img_da: for minibatch, batch_data in _img_da.map_batch( self._preproc_images, - batch_size=minibatch_size, + batch_size=self._minibatch_size, pool=self._pool, ): with self.monitor( @@ -131,7 +129,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): if _txt_da: for minibatch, batch_data in _txt_da.map_batch( self._preproc_texts, - batch_size=minibatch_size, + batch_size=self._minibatch_size, pool=self._pool, ): with self.monitor( diff --git a/server/clip_server/executors/clip_tensorrt.py b/server/clip_server/executors/clip_tensorrt.py index 4dc369a18..dd2b13542 100644 --- a/server/clip_server/executors/clip_tensorrt.py +++ b/server/clip_server/executors/clip_tensorrt.py @@ -81,7 +81,6 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): traversal_paths = parameters.get('traversal_paths', self._traversal_paths) - minibatch_size = parameters.get('minibatch_size', self._minibatch_size) _img_da = DocumentArray() _txt_da = DocumentArray() @@ -92,7 +91,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): if _img_da: for minibatch, batch_data in _img_da.map_batch( self._preproc_images, - batch_size=minibatch_size, + batch_size=self._minibatch_size, pool=self._pool, ): with self.monitor( @@ -111,7 +110,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): if _txt_da: for minibatch, batch_data in _txt_da.map_batch( self._preproc_texts, - batch_size=minibatch_size, + batch_size=self._minibatch_size, pool=self._pool, ): with self.monitor( diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index 6aa9b3cf5..f7e861a84 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -91,7 +91,6 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @requests async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): traversal_paths = parameters.get('traversal_paths', self._traversal_paths) - minibatch_size = parameters.get('minibatch_size', self._minibatch_size) _img_da = DocumentArray() _txt_da = DocumentArray() @@ -103,7 +102,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): if _img_da: for minibatch, batch_data in _img_da.map_batch( self._preproc_images, - batch_size=minibatch_size, + batch_size=self._minibatch_size, pool=self._pool, ): with self.monitor( @@ -121,7 +120,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): if _txt_da: for minibatch, batch_data in _txt_da.map_batch( self._preproc_texts, - batch_size=minibatch_size, + batch_size=self._minibatch_size, pool=self._pool, ): with self.monitor(