feat: support custom onnx file and update model signatures #761

Merged
52 commits merged on Jul 12, 2022
The diff below shows changes from 9 of the 52 commits.

Commits (52):
7fbb041
feat: allow custom onnx file
ZiniuYu Jun 27, 2022
10dd4f5
fix: path name
ZiniuYu Jun 27, 2022
bcad79c
fix: validate model path
ZiniuYu Jun 28, 2022
1e8ea3b
chore: improve error message
ZiniuYu Jun 28, 2022
9039b8f
test: add custom path unit test
ZiniuYu Jun 28, 2022
5c874a1
test: add test cases
ZiniuYu Jun 28, 2022
6413f74
test: add test cases
ZiniuYu Jun 28, 2022
8cff7ba
test: add test cases
ZiniuYu Jun 28, 2022
7affea9
fix: reindent
ZiniuYu Jun 28, 2022
fc3a41f
fix: change type to int32
ZiniuYu Jun 30, 2022
5592455
fix: modify text input
ZiniuYu Jul 1, 2022
14da62b
chore: format code
ZiniuYu Jul 1, 2022
1dd7589
chore: update model links
ZiniuYu Jul 1, 2022
6802d25
fix: update links
ZiniuYu Jul 1, 2022
6242203
fix: typo
ZiniuYu Jul 1, 2022
3b9e917
fix: add attention mask for onnx
ZiniuYu Jul 1, 2022
e704242
fix: trt text encode key
ZiniuYu Jul 4, 2022
1025ffd
fix: fix trt shape
ZiniuYu Jul 4, 2022
55a18ca
fix: trt convert
ZiniuYu Jul 5, 2022
4ac14b8
fix: trt convert
ZiniuYu Jul 5, 2022
f748e26
fix: tensorrt parse model
ZiniuYu Jul 5, 2022
22d8a90
fix: add md5 verification
ZiniuYu Jul 5, 2022
5b0a251
fix: add md5 verification
ZiniuYu Jul 6, 2022
9ab65cc
feat: add md5 validation
ZiniuYu Jul 6, 2022
3174c8e
feat: add torch md5
ZiniuYu Jul 6, 2022
bb65ffc
feat: add torch md5
ZiniuYu Jul 6, 2022
1e7bb3f
feat: add onnx md5
ZiniuYu Jul 6, 2022
3380e89
fix: md5 validation
ZiniuYu Jul 6, 2022
0f80a95
chore: clean up
ZiniuYu Jul 6, 2022
33217d6
fix: typo
ZiniuYu Jul 6, 2022
66b49f7
fix: typo
ZiniuYu Jul 7, 2022
aaf25e0
fix: typo
ZiniuYu Jul 7, 2022
43ba806
fix: correct path
ZiniuYu Jul 7, 2022
7df5b2b
fix: trt path
ZiniuYu Jul 7, 2022
916ab5c
test: add md5 test
ZiniuYu Jul 7, 2022
8ab7e89
test: add path test
ZiniuYu Jul 7, 2022
d20785b
fix: house keeping
ZiniuYu Jul 7, 2022
3bd3b5d
fix: house keeping
ZiniuYu Jul 7, 2022
f3a059e
fix: house keeping
ZiniuYu Jul 7, 2022
21c857c
fix: md5 test case
ZiniuYu Jul 7, 2022
8d63465
fix: modify visual signature
ZiniuYu Jul 7, 2022
7868308
fix: modify visual signature
ZiniuYu Jul 7, 2022
25c8fdc
fix: improve download retry
ZiniuYu Jul 7, 2022
4e02508
fix: trt timeout 30 min
ZiniuYu Jul 7, 2022
dfcd82c
fix: modify download logic
ZiniuYu Jul 10, 2022
6bea114
docs: update trt
ZiniuYu Jul 10, 2022
cd631c5
fix: validation
ZiniuYu Jul 11, 2022
e63cb46
fix: polish download with md5
numb3r3 Jul 11, 2022
45fb0f6
fix: polish download with md5
numb3r3 Jul 11, 2022
f6facd0
fix: stop with max retries
numb3r3 Jul 11, 2022
56d0dd3
fix: use forloop
numb3r3 Jul 11, 2022
94008ac
test: none regular file
ZiniuYu Jul 12, 2022

server/clip_server/executors/clip_onnx.py (2 additions, 1 deletion)
@@ -23,6 +23,7 @@ def __init__(
         num_worker_preprocess: int = 4,
         minibatch_size: int = 32,
         traversal_paths: str = '@r',
+        model_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -33,7 +34,7 @@ def __init__(
         self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
         self._pool = ThreadPool(processes=num_worker_preprocess)

-        self._model = CLIPOnnxModel(name)
+        self._model = CLIPOnnxModel(name, model_path)

         import torch
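
With this change the executor simply forwards `model_path` to `CLIPOnnxModel`. A minimal usage sketch, not part of the diff: the port and the local directory are placeholders, and the directory must contain `textual.onnx` and `visual.onnx`.

import os

from docarray import Document
from jina import Flow
from clip_server.executors.clip_onnx import CLIPEncoder

# Placeholder directory holding the two exported ONNX files:
#   <custom_dir>/textual.onnx and <custom_dir>/visual.onnx
custom_dir = os.path.expanduser('~/.cache/clip/ViT-B-32')

f = Flow(port=12345).add(  # any free port
    name='onnx',
    uses=CLIPEncoder,
    uses_with={'name': 'ViT-B/32', 'model_path': custom_dir},
)

with f:
    # Documents come back with embeddings computed by the custom ONNX model
    docs = f.post('/', [Document(text='Hello world')])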

server/clip_server/model/clip_onnx.py (20 additions, 8 deletions)
@@ -17,15 +17,27 @@


 class CLIPOnnxModel:
-    def __init__(self, name: str = None):
+    def __init__(self, name: str = None, model_path: str = None):
         if name in _MODELS:
-            cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}')
-            self._textual_path = _download(
-                _S3_BUCKET + _MODELS[name][0], cache_dir, with_resume=True
-            )
-            self._visual_path = _download(
-                _S3_BUCKET + _MODELS[name][1], cache_dir, with_resume=True
-            )
+            if not model_path:
+                cache_dir = os.path.expanduser(
+                    f'~/.cache/clip/{name.replace("/", "-")}'
+                )
+                self._textual_path = _download(
+                    _S3_BUCKET + _MODELS[name][0], cache_dir, with_resume=True
+                )
+                self._visual_path = _download(
+                    _S3_BUCKET + _MODELS[name][1], cache_dir, with_resume=True
+                )
+            elif os.path.isdir(model_path):
+                self._textual_path = os.path.join(model_path, 'textual.onnx')
+                self._visual_path = os.path.join(model_path, 'visual.onnx')
+                if not os.path.isfile(self._textual_path) or not os.path.isfile(
+                    self._visual_path
+                ):
+                    raise RuntimeError(
+                        f'{model_path} does not contain `textual.onnx` and `visual.onnx`'
+                    )
         else:
             raise RuntimeError(
                 f'Model {name} not found; available models = {available_models()}'
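
To illustrate the validation above (a sketch, not part of the diff; `my_onnx_dir` is a placeholder): `model_path` must be an existing directory that contains both ONNX files.

from clip_server.model.clip_onnx import CLIPOnnxModel

# Expected layout of a custom model directory:
#   my_onnx_dir/
#   ├── textual.onnx
#   └── visual.onnx
model = CLIPOnnxModel('ViT-B/32', model_path='my_onnx_dir')

# If 'my_onnx_dir' exists but one of the two files is missing, __init__ raises:
#   RuntimeError: my_onnx_dir does not contain `textual.onnx` and `visual.onnx`
# If the given name is not in _MODELS at all, the final else branch raises:
#   RuntimeError: Model ... not found; available models = ...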

tests/conftest.py (17 additions, 7 deletions)
@@ -16,16 +16,26 @@ def random_port():
     return random_port


-@pytest.fixture(scope='session', params=['onnx', 'torch', 'hg'])
+@pytest.fixture(scope='session', params=['onnx', 'torch', 'hg', 'onnx_custom'])
 def make_flow(port_generator, request):
-    if request.param == 'onnx':
-        from clip_server.executors.clip_onnx import CLIPEncoder
-    elif request.param == 'torch':
-        from clip_server.executors.clip_torch import CLIPEncoder
+    if request.param != 'onnx_custom':
+        if request.param == 'onnx':
+            from clip_server.executors.clip_onnx import CLIPEncoder
+        elif request.param == 'torch':
+            from clip_server.executors.clip_torch import CLIPEncoder
+        else:
+            from clip_server.executors.clip_hg import CLIPEncoder
+
+        f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
     else:
-        from clip_server.executors.clip_hg import CLIPEncoder
+        import os
+        from clip_server.executors.clip_onnx import CLIPEncoder

-    f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
+        f = Flow(port=port_generator()).add(
+            name=request.param,
+            uses=CLIPEncoder,
+            uses_with={'model_path': os.path.expanduser('~/.cache/clip/ViT-B-32')},
+        )
     with f:
         yield f

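A note on the new `onnx_custom` case: it points `model_path` at `~/.cache/clip/ViT-B-32`, the same directory the default ONNX path downloads into, so the two files are expected to already be there. A sketch of populating that cache up front (assuming network access):

from clip_server.model.clip_onnx import CLIPOnnxModel

# Instantiating the default ONNX model once downloads textual.onnx and
# visual.onnx into ~/.cache/clip/ViT-B-32, which the 'onnx_custom' fixture
# then reuses as a "custom" model_path.
CLIPOnnxModel('ViT-B/32')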

tests/test_server.py (35 additions, 1 deletion)
@@ -3,12 +3,12 @@
 import pytest
 from clip_server.model.clip import _transform_ndarray, _transform_blob, _download
 from docarray import Document
+from jina import Flow
 import numpy as np


 def test_server_download(tmpdir):
     _download('https://docarray.jina.ai/_static/favicon.png', tmpdir, with_resume=False)
-
     target_path = os.path.join(tmpdir, 'favicon.png')
     file_size = os.path.getsize(target_path)
     assert file_size > 0
@@ -25,6 +25,40 @@ def test_server_download(tmpdir):
     assert not os.path.exists(part_path)


+def test_make_onnx_flow_custom_path_wrong_name(port_generator):
+    from clip_server.executors.clip_onnx import CLIPEncoder
+    import os
+
+    f = Flow(port=port_generator()).add(
+        name='onnx',
+        uses=CLIPEncoder,
+        uses_with={
+            'name': 'ABC',
+            'model_path': os.path.expanduser('~/.cache/clip/ViT-B-32'),
+        },
+    )
+    with pytest.raises(Exception) as info:
+        with f:
+            f.post('/', Document(text='Hello world'))
+
+
+def test_make_onnx_flow_custom_path_wrong_path(port_generator):
+    from clip_server.executors.clip_onnx import CLIPEncoder
+    import os
+
+    f = Flow(port=port_generator()).add(
+        name='onnx',
+        uses=CLIPEncoder,
+        uses_with={
+            'name': 'ViT-B/32',
+            'model_path': 'ABC',
+        },
+    )
+    with pytest.raises(Exception) as info:
+        with f:
+            f.post('/', Document(text='Hello world'))
+
+
 @pytest.mark.parametrize(
     'image_uri',
     [
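
To run only the new custom-path checks locally, something like `pytest tests/test_server.py -k custom_path` should select both tests added above (assuming the repo's usual pytest setup).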