Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support custom onnx file and update model signatures #761

Merged
merged 52 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
7fbb041
feat: allow custom onnx file
ZiniuYu Jun 27, 2022
10dd4f5
fix: path name
ZiniuYu Jun 27, 2022
bcad79c
fix: validate model path
ZiniuYu Jun 28, 2022
1e8ea3b
chore: improve error message
ZiniuYu Jun 28, 2022
9039b8f
test: add custom path unit test
ZiniuYu Jun 28, 2022
5c874a1
test: add test cases
ZiniuYu Jun 28, 2022
6413f74
test: add test cases
ZiniuYu Jun 28, 2022
8cff7ba
test: add test cases
ZiniuYu Jun 28, 2022
7affea9
fix: reindent
ZiniuYu Jun 28, 2022
fc3a41f
fix: change type to int32
ZiniuYu Jun 30, 2022
5592455
fix: modify text input
ZiniuYu Jul 1, 2022
14da62b
chore: format code
ZiniuYu Jul 1, 2022
1dd7589
chore: update model links
ZiniuYu Jul 1, 2022
6802d25
fix: update links
ZiniuYu Jul 1, 2022
6242203
fix: typo
ZiniuYu Jul 1, 2022
3b9e917
fix: add attention mask for onnx
ZiniuYu Jul 1, 2022
e704242
fix: trt text encode key
ZiniuYu Jul 4, 2022
1025ffd
fix: fix trt shape
ZiniuYu Jul 4, 2022
55a18ca
fix: trt convert
ZiniuYu Jul 5, 2022
4ac14b8
fix: trt convert
ZiniuYu Jul 5, 2022
f748e26
fix: tensorrt parse model
ZiniuYu Jul 5, 2022
22d8a90
fix: add md5 verification
ZiniuYu Jul 5, 2022
5b0a251
fix: add md5 verification
ZiniuYu Jul 6, 2022
9ab65cc
feat: add md5 validation
ZiniuYu Jul 6, 2022
3174c8e
feat: add torch md5
ZiniuYu Jul 6, 2022
bb65ffc
feat: add torch md5
ZiniuYu Jul 6, 2022
1e7bb3f
feat: add onnx md5
ZiniuYu Jul 6, 2022
3380e89
fix: md5 validation
ZiniuYu Jul 6, 2022
0f80a95
chore: clean up
ZiniuYu Jul 6, 2022
33217d6
fix: typo
ZiniuYu Jul 6, 2022
66b49f7
fix: typo
ZiniuYu Jul 7, 2022
aaf25e0
fix: typo
ZiniuYu Jul 7, 2022
43ba806
fix: correct path
ZiniuYu Jul 7, 2022
7df5b2b
fix: trt path
ZiniuYu Jul 7, 2022
916ab5c
test: add md5 test
ZiniuYu Jul 7, 2022
8ab7e89
test: add path test
ZiniuYu Jul 7, 2022
d20785b
fix: house keeping
ZiniuYu Jul 7, 2022
3bd3b5d
fix: house keeping
ZiniuYu Jul 7, 2022
f3a059e
fix: house keeping
ZiniuYu Jul 7, 2022
21c857c
fix: md5 test case
ZiniuYu Jul 7, 2022
8d63465
fix: modify visual signature
ZiniuYu Jul 7, 2022
7868308
fix: modify visual signature
ZiniuYu Jul 7, 2022
25c8fdc
fix: improve download retry
ZiniuYu Jul 7, 2022
4e02508
fix: trt timeout 30 min
ZiniuYu Jul 7, 2022
dfcd82c
fix: modify download logic
ZiniuYu Jul 10, 2022
6bea114
docs: update trt
ZiniuYu Jul 10, 2022
cd631c5
fix: validation
ZiniuYu Jul 11, 2022
e63cb46
fix: polish download with md5
numb3r3 Jul 11, 2022
45fb0f6
fix: polish download with md5
numb3r3 Jul 11, 2022
f6facd0
fix: stop with max retires
numb3r3 Jul 11, 2022
56d0dd3
fix: use forloop
numb3r3 Jul 11, 2022
94008ac
test: none regular file
ZiniuYu Jul 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def setup(app):
)
app.add_config_value(
name='server_address',
default=os.getenv('JINA_DOCSBOT_SERVER', 'https://jina-ai-clip-as-service.docsqa.jina.ai'),
default=os.getenv(
'JINA_DOCSBOT_SERVER', 'https://jina-ai-clip-as-service.docsqa.jina.ai'
),
rebuild='',
)
app.connect('builder-inited', configure_qa_bot_ui)
3 changes: 2 additions & 1 deletion server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
num_worker_preprocess: int = 4,
minibatch_size: int = 32,
traversal_paths: str = '@r',
model_path: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -33,7 +34,7 @@ def __init__(
self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
self._pool = ThreadPool(processes=num_worker_preprocess)

self._model = CLIPOnnxModel(name)
self._model = CLIPOnnxModel(name, model_path)

import torch

Expand Down
4 changes: 2 additions & 2 deletions server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
documentation='images encode time in seconds',
):
minibatch.embeddings = (
self._model.encode_image(batch_data)
self._model.encode_image(batch_data['pixel_values'])
.cpu()
.numpy()
.astype(np.float32)
Expand All @@ -126,7 +126,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
documentation='texts encode time in seconds',
):
minibatch.embeddings = (
self._model.encode_text(batch_data)
self._model.encode_text(batch_data['input_ids'])
.cpu()
.numpy()
.astype(np.float32)
Expand Down
21 changes: 13 additions & 8 deletions server/clip_server/executors/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, List, Callable, Any
from typing import Tuple, List, Callable, Any, Dict
import torch
import numpy as np
from docarray import Document, DocumentArray
Expand All @@ -20,7 +20,7 @@ def preproc_image(
preprocess_fn: Callable,
device: str = 'cpu',
return_np: bool = False,
) -> Tuple['DocumentArray', List[Any]]:
) -> Tuple['DocumentArray', Dict]:

tensors_batch = []

Expand All @@ -45,22 +45,27 @@ def preproc_image(
else:
tensors_batch = tensors_batch.to(device)

return da, tensors_batch
return da, {'pixel_values': tensors_batch}


def preproc_text(
da: 'DocumentArray', device: str = 'cpu', return_np: bool = False
) -> Tuple['DocumentArray', List[Any]]:
) -> Tuple['DocumentArray', Dict]:

tensors_batch = clip.tokenize(da.texts).detach()
inputs = clip.tokenize(da.texts)
inputs['input_ids'] = inputs['input_ids'].detach()

if return_np:
tensors_batch = tensors_batch.cpu().numpy().astype(np.int64)
inputs['input_ids'] = inputs['input_ids'].cpu().numpy().astype(np.int32)
inputs['attention_mask'] = (
inputs['attention_mask'].cpu().numpy().astype(np.int32)
)
else:
tensors_batch = tensors_batch.to(device)
inputs['input_ids'] = inputs['input_ids'].to(device)
inputs['attention_mask'] = inputs['attention_mask'].to(device)

da[:, 'mime_type'] = 'text'
return da, tensors_batch
return da, inputs


def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'):
Expand Down
154 changes: 94 additions & 60 deletions server/clip_server/model/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io
import os
import hashlib
import shutil
import urllib
import warnings
Expand All @@ -26,15 +27,15 @@

_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch/'
_MODELS = {
'RN50': 'RN50.pt',
'RN101': 'RN101.pt',
'RN50x4': 'RN50x4.pt',
'RN50x16': 'RN50x16.pt',
'RN50x64': 'RN50x64.pt',
'ViT-B/32': 'ViT-B-32.pt',
'ViT-B/16': 'ViT-B-16.pt',
'ViT-L/14': 'ViT-L-14.pt',
'ViT-L/14@336px': 'ViT-L-14-336px.pt',
'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'),
'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'),
'RN50x4': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'),
'RN50x16': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'),
'RN50x64': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'),
'ViT-B/32': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'),
'ViT-B/16': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'),
'ViT-L/14': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'),
'ViT-L/14@336px': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'),
}

MODEL_SIZE = {
Expand All @@ -50,12 +51,21 @@
}


def _download(url: str, root: str, with_resume: bool = True):
def _download(
url: str,
root: str,
md5: str = None,
with_resume: bool = True,
max_attempts: int = 3,
) -> str:

os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)

download_target = os.path.join(root, filename)
if os.path.isfile(download_target):
if os.path.isfile(download_target) and (
not md5 or hashlib.md5(open(download_target, 'rb').read()).hexdigest() == md5
):
return download_target

if os.path.exists(download_target) and not os.path.isfile(download_target):
Expand All @@ -81,53 +91,71 @@ def _download(url: str, root: str, with_resume: bool = True):
)

with progress:

task = progress.add_task('download', filename=url, start=False)
for _ in range(max_attempts):

tmp_file_path = download_target + '.part'
resume_byte_pos = (
os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0
)

total_bytes = -1
try:
# resolve the 403 error by passing a valid user-agent
req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})

total_bytes = int(
urllib.request.urlopen(req).info().get('Content-Length', -1)
tmp_file_path = download_target + '.part'
resume_byte_pos = (
os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0
)

mode = 'ab' if (with_resume and resume_byte_pos) else 'wb'

with open(tmp_file_path, mode) as output:

progress.update(task, total=total_bytes)

progress.start_task(task)

if resume_byte_pos and with_resume:
progress.update(task, advance=resume_byte_pos)
req.headers['Range'] = f'bytes={resume_byte_pos}-'

with urllib.request.urlopen(req) as source:
while True:
buffer = source.read(8192)
if not buffer:
break
total_bytes = -1
try:
# resolve the 403 error by passing a valid user-agent
req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})

output.write(buffer)
progress.update(task, advance=len(buffer))
except Exception as ex:
raise ex
finally:
# rename the temp download file to the correct name if fully downloaded
if os.path.exists(tmp_file_path) and (
total_bytes == os.path.getsize(tmp_file_path)
):
shutil.move(tmp_file_path, download_target)
total_bytes = int(
urllib.request.urlopen(req).info().get('Content-Length', -1)
)

return download_target
mode = 'ab' if (with_resume and resume_byte_pos) else 'wb'

with open(tmp_file_path, mode) as output:

progress.update(task, total=total_bytes)

progress.start_task(task)

if resume_byte_pos and with_resume:
progress.update(task, advance=resume_byte_pos)
req.headers['Range'] = f'bytes={resume_byte_pos}-'

with urllib.request.urlopen(req) as source:
while True:
buffer = source.read(8192)
if not buffer:
break

output.write(buffer)
progress.update(task, advance=len(buffer))
except Exception as ex:
raise ex
ZiniuYu marked this conversation as resolved.
Show resolved Hide resolved
finally:
# rename the temp download file to the correct name if fully downloaded
if os.path.exists(tmp_file_path) and (
total_bytes == os.path.getsize(tmp_file_path)
and (
not md5
or hashlib.md5(open(tmp_file_path, 'rb').read()).hexdigest()
== md5
)
ZiniuYu marked this conversation as resolved.
Show resolved Hide resolved
):
shutil.move(tmp_file_path, download_target)
break
else:
progress.console.print(
f'MD5 mismatch for {download_target}, maybe the file is not completely downloaded. '
ZiniuYu marked this conversation as resolved.
Show resolved Hide resolved
f'Retrying now...'
)
os.remove(tmp_file_path)
progress.reset(task)

if os.path.isfile(download_target) and (
not md5 or hashlib.md5(open(download_target, 'rb').read()).hexdigest() == md5
ZiniuYu marked this conversation as resolved.
Show resolved Hide resolved
):
return download_target
else:
raise RuntimeError(f'Failed to download {url}, max attempts exceeded')


def _convert_image_to_rgb(image):
Expand Down Expand Up @@ -193,7 +221,7 @@ def load(
Whether to load the optimized JIT model or more hackable non-JIT model (default).

download_root: str
path to download the model files; by default, it uses '~/.cache/clip'
path to download the model files; by default, it uses '~/.cache/clip/'

Returns
-------
Expand All @@ -204,9 +232,11 @@ def load(
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_name, model_md5 = _MODELS[name]
model_path = _download(
numb3r3 marked this conversation as resolved.
Show resolved Hide resolved
_S3_BUCKET + _MODELS[name],
download_root or os.path.expanduser('~/.cache/clip'),
url=_S3_BUCKET + model_name,
root=download_root or os.path.expanduser('~/.cache/clip'),
md5=model_md5,
with_resume=True,
)
elif os.path.isfile(name):
Expand Down Expand Up @@ -309,7 +339,7 @@ def patch_float(module):

def tokenize(
texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True
) -> torch.LongTensor:
) -> dict:
"""
Returns the tokenized representation of given input string(s)

Expand All @@ -326,15 +356,18 @@ def tokenize(

Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
A dict of tokenized representations of the input strings and their corresponding attention masks with both
shape = [batch size, context_length]
"""
if isinstance(texts, str):
texts = [texts]

sot_token = _tokenizer.encoder['<|startoftext|>']
eot_token = _tokenizer.encoder['<|endoftext|>']
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

input_ids = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
attention_mask = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
Expand All @@ -345,6 +378,7 @@ def tokenize(
raise RuntimeError(
f'Input {texts[i]} is too long for context length {context_length}'
)
result[i, : len(tokens)] = torch.tensor(tokens)
input_ids[i, : len(tokens)] = torch.tensor(tokens)
attention_mask[i, : len(tokens)] = 1

return result
return {'input_ids': input_ids, 'attention_mask': attention_mask}
Loading