feat: add fp16 inference support (torch/onnx) #871

Merged 40 commits from add-fp16-inference-support into main on Dec 8, 2022
Changes from 36 commits
Commits (40), all by OrangeSodahub:

326e265  feat: add fp16 inference in clip_torch (Dec 4, 2022)
b0bca12  Revert "feat: add fp16 inference in clip_torch" (Dec 4, 2022)
2e09165  feat: add fp16 inference in clip_torch (Dec 4, 2022)
6191813  fix: device (Dec 4, 2022)
a15bd61  fix: str to torch.dtype (Dec 4, 2022)
71e170a  fix: layernorm (Dec 4, 2022)
28f6622  feat: add fp16 inference in clip_trt (Dec 4, 2022)
3477fae  feat: add fp16 inference in clip_onnx (Dec 4, 2022)
a5fff1d  fix: housekeeping (Dec 5, 2022)
4f6025c  fix: ci (Dec 5, 2022)
43bf259  fix: ci (Dec 6, 2022)
12ae920  fix: ci (Dec 6, 2022)
b559109  fix: ci and get test path (Dec 6, 2022)
8cd0bb0  fix: dtype amp and gpu test dependency (Dec 6, 2022)
4d891f0  fix: layernorm (Dec 6, 2022)
a193381  fix: cast dtype in visiontransformer (Dec 6, 2022)
c502679  fix: clip_onnx (Dec 6, 2022)
3c14f5a  fix: clip_onnx (Dec 6, 2022)
6b56623  fix: convert onnx to fp16 (Dec 6, 2022)
9984b4b  fix: dtype in preproc images (Dec 6, 2022)
2f8ed28  fix: dtype in preproc images (Dec 6, 2022)
d25a396  fix: typo (Dec 6, 2022)
f0f8b43  fix: dtype in clip_torch and fp16 in trt (Dec 6, 2022)
d47d2c9  fix: remove plain text in trt_test (Dec 6, 2022)
d3b2ff6  fix: test (Dec 6, 2022)
e597565  fix: typo (Dec 6, 2022)
f72fd99  fix: stash (Dec 6, 2022)
cb271fb  Revert "fix: stash" (Dec 6, 2022)
d199efc  fix: for test (Dec 7, 2022)
e96c31e  fix: onnx (Dec 7, 2022)
9df79a9  fix: for test (Dec 7, 2022)
9b8f60c  fix: for test (Dec 7, 2022)
4779780  fix: trt (Dec 7, 2022)
35defc9  fix: convert onnx to fp16 before convert trt (Dec 7, 2022)
20bbaa1  Merge branch 'main' into add-fp16-inference-support (Dec 7, 2022)
6d221cd  fix: discard changes in trt (Dec 8, 2022)
fc36e9b  fix: optimize fp16 test (Dec 8, 2022)
edf4629  fix: move __cast_dtype__ (Dec 8, 2022)
aa574d0  Revert "fix: move __cast_dtype__" (Dec 8, 2022)
bd1fe7c  fix: ci (Dec 8, 2022)
.github/workflows/ci.yml (4 changes: 3 additions & 1 deletion)

@@ -113,7 +113,6 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install --no-cache-dir "server/[transformers]"
- name: Test
id: test
run: |
@@ -158,6 +157,7 @@ jobs:
python -m pip install wheel pytest pytest-cov nvidia-pyindex
pip install -e "client/[test]"
pip install -e "server/[tensorrt]"
pip install -e "server/[onnx]"
{
pip install -e "server/[flash-attn]"
} || {
@@ -168,6 +168,8 @@
run: |
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
-v -s -m "gpu" ./tests/test_tensorrt.py
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
-v -s -m "gpu" ./tests/test_fp16.py
echo "::set-output name=codecov_flag::cas"
timeout-minutes: 30
env:
scripts/get-all-test-paths.sh (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@ BATCH_SIZE=3
#declare -a array1=( "tests/unit/test_*.py" )
#declare -a array2=( $(ls -d tests/unit/*/ | grep -v '__pycache__' | grep -v 'array') )
#declare -a array3=( "tests/unit/array/*.py" )
declare -a mixins=( $(find tests -name "test_*.py" | grep -v 'test_tensorrt.py') )
declare -a mixins=( $(find tests -name "test_*.py" | grep -v 'test_tensorrt.py' | grep -v 'test_fp16.py') )
declare -a array4=( "$(echo "${mixins[@]}" | xargs -n$BATCH_SIZE)" )
# array5 is currently empty because in the array/ directory, mixins is the only directory
# but add the following in case new directories are created in array/
server/clip_server/executors/clip_onnx.py (20 changes: 12 additions & 8 deletions)

@@ -27,6 +27,7 @@ def __init__(
minibatch_size: int = 32,
access_paths: str = '@r',
model_path: Optional[str] = None,
dtype: Optional[str] = None,
**kwargs,
):
"""
@@ -41,8 +42,17 @@
:param model_path: The path to the model to be used. If not specified, the model will be downloaded or loaded
from the local cache. Visit https://clip-as-service.jina.ai/user-guides/server/#use-custom-model-for-onnx
to learn how to finetune custom models.
:param dtype: inference data type, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
"""
super().__init__(**kwargs)
import torch

if not device:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self._device = device
if not dtype:
dtype = 'fp32' if self._device in ('cpu', torch.device('cpu')) else 'fp16'
self._dtype = dtype

self._minibatch_size = minibatch_size
self._access_paths = access_paths
@@ -55,18 +65,11 @@ def __init__(
self._num_worker_preprocess = num_worker_preprocess
self._pool = ThreadPool(processes=num_worker_preprocess)

self._model = CLIPOnnxModel(name, model_path)
self._model = CLIPOnnxModel(name, model_path, dtype)
self._tokenizer = Tokenizer(name)

self._image_transform = clip._transform_ndarray(self._model.image_size)

import torch

if not device:
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self._device = device

# define the priority order for the execution providers
providers = ['CPUExecutionProvider']

@@ -116,6 +119,7 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
preprocess_fn=self._image_transform,
return_np=True,
drop_image_content=drop_image_content,
dtype=self._dtype,
)

def _preproc_texts(self, docs: 'DocumentArray'):
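For context, a minimal sketch of how the new dtype option could be exercised against the ONNX executor from a Flow. The Flow layout and port are illustrative and not part of this diff; it assumes the executor class exported by clip_server.executors.clip_onnx is CLIPEncoder and that the server runs on a CUDA device.

```python
# Hypothetical usage sketch, not part of this PR: serve the ONNX backend in fp16.
from jina import Flow
from clip_server.executors.clip_onnx import CLIPEncoder

f = Flow(port=51000).add(
    name='clip_onnx',
    uses=CLIPEncoder,
    uses_with={'dtype': 'fp16'},  # with the change above, leaving dtype unset also picks fp16 on GPU
)

with f:
    f.block()
```

Leaving dtype unset keeps the previous fp32 behaviour on CPU, per the default added in __init__ above.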
server/clip_server/executors/clip_torch.py (24 changes: 19 additions & 5 deletions)

@@ -2,7 +2,7 @@
import warnings
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Dict, Optional
from typing import Dict, Union, Optional

import numpy as np
import torch
@@ -12,6 +12,7 @@
set_rank,
split_img_txt_da,
)
from clip_server.helper import __cast_dtype__
from clip_server.model import clip
from clip_server.model.clip_model import CLIPModel
from clip_server.model.tokenization import Tokenizer
@@ -28,6 +29,7 @@ def __init__(
num_worker_preprocess: int = 4,
minibatch_size: int = 32,
access_paths: str = '@r',
dtype: Optional[Union[str, torch.dtype]] = None,
**kwargs,
):
"""
@@ -40,6 +42,7 @@
number if you encounter OOM errors.
:param access_paths: The access paths to traverse on the input documents to get the images and texts to be
processed. Visit https://docarray.jina.ai/fundamentals/documentarray/access-elements for more details.
:param dtype: inference data type, if None defaults to torch.float32 if device == 'cpu' else torch.float16.
"""
super().__init__(**kwargs)

@@ -52,9 +55,17 @@
self._access_paths = kwargs['traversal_paths']

if not device:
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self._device = device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self._device = device
if isinstance(dtype, str):
dtype = __cast_dtype__.get(dtype)
elif not dtype:
dtype = (
torch.float32
if self._device in ('cpu', torch.device('cpu'))
else torch.float16
)
self._dtype = dtype

if not self._device.startswith('cuda') and (
'OMP_NUM_THREADS' not in os.environ
@@ -77,7 +88,9 @@ def __init__(
self._num_worker_preprocess = num_worker_preprocess
self._pool = ThreadPool(processes=num_worker_preprocess)

self._model = CLIPModel(name, device=self._device, jit=jit, **kwargs)
self._model = CLIPModel(
name, device=self._device, jit=jit, dtype=dtype, **kwargs
)
self._tokenizer = Tokenizer(name)
self._image_transform = clip._transform_ndarray(self._model.image_size)

@@ -96,6 +109,7 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
device=self._device,
return_np=False,
drop_image_content=drop_image_content,
dtype=self._dtype,
)

def _preproc_texts(self, docs: 'DocumentArray'):
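A standalone sketch of the dtype-resolution rule the torch executor gains above. The mapping is copied from clip_server.helper in this PR; the function name resolve_dtype is illustrative and only restates the branching shown in the diff.

```python
# Illustrative re-statement of the dtype resolution added to the torch executor.
import torch

__cast_dtype__ = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16}


def resolve_dtype(device, dtype=None):
    # explicit string names win; otherwise the device decides the default
    if isinstance(dtype, str):
        return __cast_dtype__.get(dtype)
    if dtype is None:
        return (
            torch.float32
            if device in ('cpu', torch.device('cpu'))
            else torch.float16
        )
    return dtype


assert resolve_dtype('cpu') is torch.float32
assert resolve_dtype('cuda') is torch.float16
assert resolve_dtype('cuda', dtype='fp32') is torch.float32
```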
server/clip_server/executors/helper.py (9 changes: 7 additions & 2 deletions)

@@ -1,8 +1,9 @@
from typing import Tuple, List, Callable, Any, Dict
from typing import Tuple, List, Callable, Any, Dict, Union
import torch
import numpy as np
from docarray import Document, DocumentArray
from docarray.math.distance.numpy import cosine
from clip_server.helper import __cast_dtype__


from clip_server.model.tokenization import Tokenizer
@@ -22,8 +23,12 @@ def preproc_image(
device: str = 'cpu',
return_np: bool = False,
drop_image_content: bool = False,
dtype: Union[str, torch.dtype] = torch.float32,
) -> Tuple['DocumentArray', Dict]:

if isinstance(dtype, str):
dtype = __cast_dtype__.get(dtype)

tensors_batch = []

for d in da:
@@ -42,7 +47,7 @@
if drop_image_content:
d.pop('blob', 'tensor')

tensors_batch = torch.stack(tensors_batch).type(torch.float32)
tensors_batch = torch.stack(tensors_batch).type(dtype)

if return_np:
tensors_batch = tensors_batch.cpu().numpy()
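The practical effect of the preproc_image change is that the stacked image batch now follows the requested dtype instead of always being float32. A small sketch of that behaviour; the batch size and image shape are illustrative assumptions.

```python
# Illustrative effect of the cast in preproc_image: the batch dtype follows the executor.
import torch

dtype = torch.float16                               # e.g. resolved from the 'fp16' string
tensors_batch = [torch.rand(3, 224, 224) for _ in range(4)]

batch = torch.stack(tensors_batch).type(dtype)      # previously hard-coded to torch.float32
assert batch.dtype is torch.float16

# the ONNX executor additionally asks for a numpy array, which keeps the same dtype
onnx_input = batch.cpu().numpy()
assert onnx_input.dtype.name == 'float16'
```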
server/clip_server/helper.py (4 changes: 4 additions & 0 deletions)

@@ -2,6 +2,7 @@
import os
import sys
import threading
import torch
from packaging.version import Version
from urllib.request import Request, urlopen

@@ -19,6 +20,9 @@
)


__cast_dtype__ = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16}


def _version_check(package: str = None, github_repo: str = None):
try:

server/clip_server/model/clip_onnx.py (23 changes: 21 additions & 2 deletions)

@@ -1,5 +1,5 @@
import os
from typing import Dict
from typing import Dict, Optional

from clip_server.model.pretrained_models import (
download_model,
@@ -201,8 +201,11 @@


class CLIPOnnxModel(BaseCLIPModel):
def __init__(self, name: str, model_path: str = None):
def __init__(
self, name: str, model_path: str = None, dtype: Optional[str] = 'fp32'
):
super().__init__(name)
self._dtype = dtype
if name in _MODELS:
if not model_path:
cache_dir = os.path.expanduser(
@@ -237,6 +240,22 @@ def __init__(self, name: str, model_path: str = None):
f'The given model path {model_path} should be a folder containing both '
f'`textual.onnx` and `visual.onnx`.'
)
if dtype == 'fp16':
import onnx
from onnxmltools.utils import float16_converter

_textual_model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
self._textual_path
)
)
_visual_model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
self._visual_path
)
)
onnx.save_model(_textual_model_fp16, self._textual_path)
onnx.save_model(_visual_model_fp16, self._visual_path)
else:
raise RuntimeError(
'CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}'.format(
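The fp16 branch above relies on onnxmltools, which this PR adds to the onnx extra in setup.py. As a standalone sketch, the same conversion looks roughly like this; the file names are illustrative and assume the fp32 graphs already exist on disk.

```python
# Standalone sketch of the in-place fp32 -> fp16 conversion done when dtype == 'fp16'.
import onnx
from onnxmltools.utils import float16_converter

for path in ('textual.onnx', 'visual.onnx'):         # illustrative paths
    model_fp16 = float16_converter.convert_float_to_float16_model_path(path)
    onnx.save_model(model_fp16, path)                 # overwrite the fp32 graph, as the executor does
```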
server/clip_server/model/model.py (14 changes: 12 additions & 2 deletions)

@@ -15,6 +15,7 @@
from dataclasses import dataclass
from typing import Tuple, Union, Optional
from copy import deepcopy
from clip_server.helper import __cast_dtype__
from open_clip.transformer import QuickGELU, LayerNorm, LayerNormFp32, Attention
from open_clip.timm_model import TimmModel
from open_clip.factory import _MODEL_CONFIGS
@@ -81,6 +82,11 @@ def __init__(
super().__init__(image_size, patch_size, output_dim=output_dim, **kwargs)
self.transformer = Transformer(dtype=dtype, **kwargs)

def forward(self, x: torch.Tensor):
dtype = self.transformer.get_cast_dtype()
x = x.to(dtype)
return super().forward(x)


class TextTransformer(_TextTransformer):
def __init__(
@@ -435,7 +441,9 @@ def load_openai_model(
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if dtype is None:
if isinstance(dtype, str):
dtype = __cast_dtype__.get(dtype, 'amp')
elif dtype is None:
dtype = (
torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16
)
@@ -550,7 +558,9 @@ def load_openclip_model(
pretrained_image: bool = False,
dtype: Optional[Union[str, torch.dtype]] = None,
):
if dtype is None:
if isinstance(dtype, str):
dtype = __cast_dtype__.get(dtype)
elif dtype is None:
dtype = (
torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16
)
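The new VisionTransformer.forward override exists because preprocessing can hand over float32 pixels while the transformer weights are float16. A toy illustration of that cast; the modules below are stand-ins, not the real CLIP vision tower.

```python
# Toy illustration of the input cast added to VisionTransformer.forward.
import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    def __init__(self, dtype=torch.float16):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(8, dtype=dtype))

    def get_cast_dtype(self) -> torch.dtype:
        return self.scale.dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale                      # elementwise, so fp16 works on CPU too


class ToyVisionTower(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = ToyTransformer()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = self.transformer.get_cast_dtype()
        x = x.to(dtype)                            # mirrors the cast added in this PR
        return self.transformer(x)


out = ToyVisionTower()(torch.rand(2, 8))           # input arrives as float32
assert out.dtype is torch.float16
```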
server/setup.py (1 change: 1 addition & 0 deletions)

@@ -53,6 +53,7 @@
'onnx': [
'onnxruntime',
'onnx',
'onnxmltools',
]
+ (['onnxruntime-gpu>=1.8.0'] if sys.platform != 'darwin' else []),
'tensorrt': ['nvidia-tensorrt'],
tests/test_fp16.py (63 changes: 63 additions & 0 deletions)

@@ -0,0 +1,63 @@
import os

import pytest
from docarray import Document, DocumentArray
from jina import Flow

from clip_client.client import Client


@pytest.mark.gpu
@pytest.mark.parametrize(
'inputs',
[
['hello, world', 'goodbye, world'],
('hello, world', 'goodbye, world'),
lambda: ('hello, world' for _ in range(10)),
[
'https://docarray.jina.ai/_static/favicon.png',
f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg',
'hello, world',
],
],
)
def test_plain_inputs(make_flow, inputs):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
r = c.encode(inputs if not callable(inputs) else inputs())
assert (
r.shape[0] == len(list(inputs)) if not callable(inputs) else len(list(inputs()))
)


@pytest.mark.gpu
@pytest.mark.parametrize(
'inputs',
[
[Document(text='hello, world'), Document(text='goodbye, world')],
DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]),
lambda: (Document(text='hello, world') for _ in range(10)),
DocumentArray(
[
Document(uri='https://docarray.jina.ai/_static/favicon.png'),
Document(
uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
),
Document(text='hello, world'),
Document(
uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
).load_uri_to_image_tensor(),
]
),
DocumentArray.from_files(
f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
),
],
)
def test_docarray_inputs(make_flow, inputs):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
r = c.encode(inputs if not callable(inputs) else inputs())
assert isinstance(r, DocumentArray)
assert r.embeddings.shape
assert not r[0].tensor
if hasattr(inputs, '__len__'):
assert inputs[0] is r[0]
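The tests above depend on a make_flow fixture that is not part of this diff. A plausible conftest-style sketch is shown below; the executor choice, dtype value, port, and fixture scope are assumptions about what the repository's tests/conftest.py provides, not facts from this PR.

```python
# Hypothetical sketch of the make_flow fixture used by tests/test_fp16.py.
import pytest
from jina import Flow

from clip_server.executors.clip_torch import CLIPEncoder


@pytest.fixture(scope='session')
def make_flow():
    with Flow(port=51009).add(uses=CLIPEncoder, uses_with={'dtype': 'fp16'}) as f:
        yield f
```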