Skip to content

Commit

Permalink
feat: add fp16 inference in clip_torch
Browse files Browse the repository at this point in the history
  • Loading branch information
OrangeSodahub committed Dec 4, 2022
1 parent c7af9f7 commit 326e265
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
num_worker_preprocess: int = 4,
minibatch_size: int = 32,
access_paths: str = '@r',
dtype: Optional[str] = None,
**kwargs,
):
"""
Expand All @@ -40,6 +41,7 @@ def __init__(
number if you encounter OOM errors.
:param access_paths: The access paths to traverse on the input documents to get the images and texts to be
processed. Visit https://docarray.jina.ai/fundamentals/documentarray/access-elements for more details.
:param dtype: inference data type, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
"""
super().__init__(**kwargs)

Expand All @@ -55,6 +57,11 @@ def __init__(
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self._device = device
if dtype is None:
dtype = (
'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16'
)
self.dtype = dtype

if not self._device.startswith('cuda') and (
'OMP_NUM_THREADS' not in os.environ
Expand All @@ -77,7 +84,7 @@ def __init__(
self._num_worker_preprocess = num_worker_preprocess
self._pool = ThreadPool(processes=num_worker_preprocess)

self._model = CLIPModel(name, device=self._device, jit=jit, **kwargs)
self._model = CLIPModel(name, device=self._device, jit=jit, dtype=dtype, **kwargs)
self._tokenizer = Tokenizer(name)
self._image_transform = clip._transform_ndarray(self._model.image_size)

Expand Down

0 comments on commit 326e265

Please sign in to comment.