feat: support custom onnx file and update model signatures (#761)
* feat: allow custom onnx file

* fix: path name

* fix: validate model path

* chore: improve error message

* test: add custom path unit test

* test: add test cases

* test: add test cases

* test: add test cases

* fix: reindent

* fix: change type to int32

* fix: modify text input

* chore: format code

* chore: update model links

* fix: update links

* fix: typo

* fix: add attention mask for onnx

* fix: trt text encode key

* fix: fix trt shape

* fix: trt convert

* fix: trt convert

* fix: tensorrt parse model

* fix: add md5 verification

* fix: add md5 verification

* feat: add md5 validation

* feat: add torch md5

* feat: add torch md5

* feat: add onnx md5

* fix: md5 validation

* chore: clean up

* fix: typo

* fix: typo

* fix: typo

* fix: correct path

* fix: trt path

* test: add md5 test

* test: add path test

* fix: house keeping

* fix: house keeping

* fix: house keeping

* fix: md5 test case

* fix: modify visual signature

* fix: modify visual signature

* fix: improve download retry

* fix: trt timeout 30 min

* fix: modify download logic

* docs: update trt

* fix: validation

* fix: polish download with md5

* fix: polish download with md5

* fix: stop with max retries

* fix: use forloop

* test: none regular file

Co-authored-by: numb3r3 <[email protected]>
ZiniuYu and numb3r3 authored Jul 12, 2022
1 parent ed1b92d commit ee7da10
Showing 11 changed files with 363 additions and 182 deletions.
4 changes: 3 additions & 1 deletion docs/conf.py
@@ -234,7 +234,9 @@ def setup(app):
     )
     app.add_config_value(
         name='server_address',
-        default=os.getenv('JINA_DOCSBOT_SERVER', 'https://jina-ai-clip-as-service.docsqa.jina.ai'),
+        default=os.getenv(
+            'JINA_DOCSBOT_SERVER', 'https://jina-ai-clip-as-service.docsqa.jina.ai'
+        ),
         rebuild='',
     )
     app.connect('builder-inited', configure_qa_bot_ui)
4 changes: 2 additions & 2 deletions docs/user-guides/server.md
@@ -61,7 +61,7 @@ The procedure and UI of ONNX and TensorRT runtime would look the same as Pytorch

## Model support

-OpenAI has released 9 models so far. `ViT-B/32` is the default model in all runtimes. Due to runtime limitations, not every runtime supports all nine models. Note also that different models produce different output dimensions, which affects your downstream applications: switching from one model to another makes your embeddings incomparable and breaks those applications. Below is a list of the models supported by each runtime and their corresponding sizes. We include the disk usage (in delta) and the peak RAM and VRAM usage (in delta) when running on a single Nvidia TITAN RTX GPU (24GB VRAM) with a default `minibatch_size=32` in the server and a default `batch_size=8` in the client.
+OpenAI has released 9 models so far. `ViT-B/32` is the default model in all runtimes. Due to runtime limitations, not every runtime supports all nine models. Note also that different models produce different output dimensions, which affects your downstream applications: switching from one model to another makes your embeddings incomparable and breaks those applications. Below is a list of the models supported by each runtime and their corresponding sizes. We include the disk usage (in delta) and the peak RAM and VRAM usage (in delta) when running on a single Nvidia TITAN RTX GPU (24GB VRAM) with a default `minibatch_size=32` in the server (PyTorch runtime) and a default `batch_size=8` in the client.

 | Model          | PyTorch | ONNX | TensorRT | Output Dimension | Disk Usage (MB) | Peak RAM Usage (GB) | Peak VRAM Usage (GB) |
 |----------------|---------|------|----------|------------------|-----------------|---------------------|----------------------|
@@ -72,7 +72,7 @@ Open AI has released 9 models so far. `ViT-B/32` is used as default model in all
 | RN50x64 |||| 1024 | 1382 | 4.08 | 2.98 |
 | ViT-B/32 |||| 512 | 351 | 3.20 | 1.40 |
 | ViT-B/16 |||| 512 | 354 | 3.20 | 1.44 |
-| ViT-L/14 ||| | 768 | 933 | 3.66 | 2.04 |
+| ViT-L/14 ||| | 768 | 933 | 3.66 | 2.04 |
 | ViT-L/14-336px |||| 768 | 934 | 3.74 | 2.23 |


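Since the output dimension is model-dependent, a quick way to see the incompatibility described above is to compare embedding shapes from the client. A minimal sketch, assuming a server is already running; the address is a placeholder:

```python
from clip_client import Client

c = Client('grpc://0.0.0.0:51000')  # placeholder address

emb = c.encode(['a photo of a dog'])
print(emb.shape)  # (1, 512) when the server runs ViT-B/32; (1, 768) with ViT-L/14
```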
3 changes: 2 additions & 1 deletion server/clip_server/executors/clip_onnx.py
@@ -23,6 +23,7 @@ def __init__(
         num_worker_preprocess: int = 4,
         minibatch_size: int = 32,
         traversal_paths: str = '@r',
+        model_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -33,7 +34,7 @@ def __init__(
         self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
         self._pool = ThreadPool(processes=num_worker_preprocess)

-        self._model = CLIPOnnxModel(name)
+        self._model = CLIPOnnxModel(name, model_path)

         import torch

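The new `model_path` argument is the user-facing hook for the custom ONNX file support added in this PR. A minimal sketch of passing it at construction time; the executor class name `CLIPEncoder` and the layout expected under `model_path` are assumptions (see the path-validation commits above):

```python
from clip_server.executors.clip_onnx import CLIPEncoder  # class name assumed

encoder = CLIPEncoder(
    name='ViT-B/32',
    model_path='/path/to/custom_onnx',  # hypothetical path; validated server-side
)
```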
4 changes: 2 additions & 2 deletions server/clip_server/executors/clip_torch.py
@@ -108,7 +108,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
                 documentation='images encode time in seconds',
             ):
                 minibatch.embeddings = (
-                    self._model.encode_image(batch_data)
+                    self._model.encode_image(batch_data['pixel_values'])
                     .cpu()
                     .numpy()
                     .astype(np.float32)
@@ -126,7 +126,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
                 documentation='texts encode time in seconds',
             ):
                 minibatch.embeddings = (
-                    self._model.encode_text(batch_data)
+                    self._model.encode_text(batch_data['input_ids'])
                     .cpu()
                     .numpy()
                     .astype(np.float32)
21 changes: 13 additions & 8 deletions server/clip_server/executors/helper.py
@@ -1,4 +1,4 @@
-from typing import Tuple, List, Callable, Any
+from typing import Tuple, List, Callable, Any, Dict
 import torch
 import numpy as np
 from docarray import Document, DocumentArray
@@ -20,7 +20,7 @@ def preproc_image(
     preprocess_fn: Callable,
     device: str = 'cpu',
     return_np: bool = False,
-) -> Tuple['DocumentArray', List[Any]]:
+) -> Tuple['DocumentArray', Dict]:

     tensors_batch = []

@@ -45,22 +45,27 @@ def preproc_image(
     else:
         tensors_batch = tensors_batch.to(device)

-    return da, tensors_batch
+    return da, {'pixel_values': tensors_batch}


 def preproc_text(
     da: 'DocumentArray', device: str = 'cpu', return_np: bool = False
-) -> Tuple['DocumentArray', List[Any]]:
+) -> Tuple['DocumentArray', Dict]:

-    tensors_batch = clip.tokenize(da.texts).detach()
+    inputs = clip.tokenize(da.texts)
+    inputs['input_ids'] = inputs['input_ids'].detach()

     if return_np:
-        tensors_batch = tensors_batch.cpu().numpy().astype(np.int64)
+        inputs['input_ids'] = inputs['input_ids'].cpu().numpy().astype(np.int32)
+        inputs['attention_mask'] = (
+            inputs['attention_mask'].cpu().numpy().astype(np.int32)
+        )
     else:
-        tensors_batch = tensors_batch.to(device)
+        inputs['input_ids'] = inputs['input_ids'].to(device)
+        inputs['attention_mask'] = inputs['attention_mask'].to(device)

     da[:, 'mime_type'] = 'text'
-    return da, tensors_batch
+    return da, inputs


 def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'):
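Both helpers now hand back a dict instead of a bare tensor batch: `preproc_image` wraps its output as `{'pixel_values': ...}`, and `preproc_text` returns `input_ids` plus the new `attention_mask`, cast to int32 on the numpy path. A small sketch of the text path using the diff's own names (shapes assume the default context length of 77):

```python
from docarray import Document, DocumentArray
from clip_server.executors.helper import preproc_text

da = DocumentArray([Document(text='hello'), Document(text='a photo of a dog')])
da, inputs = preproc_text(da, return_np=True)

print(inputs['input_ids'].shape)       # (2, 77), int32 after this change
print(inputs['attention_mask'].shape)  # (2, 77); 1 on real tokens, 0 on padding
print(da[0].mime_type)                 # 'text'
```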
162 changes: 96 additions & 66 deletions server/clip_server/model/clip.py
@@ -2,6 +2,7 @@

 import io
 import os
+import hashlib
 import shutil
 import urllib
 import warnings
@@ -26,15 +27,15 @@

 _S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch/'
 _MODELS = {
-    'RN50': 'RN50.pt',
-    'RN101': 'RN101.pt',
-    'RN50x4': 'RN50x4.pt',
-    'RN50x16': 'RN50x16.pt',
-    'RN50x64': 'RN50x64.pt',
-    'ViT-B/32': 'ViT-B-32.pt',
-    'ViT-B/16': 'ViT-B-16.pt',
-    'ViT-L/14': 'ViT-L-14.pt',
-    'ViT-L/14@336px': 'ViT-L-14-336px.pt',
+    'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'),
+    'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'),
+    'RN50x4': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'),
+    'RN50x16': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'),
+    'RN50x64': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'),
+    'ViT-B/32': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'),
+    'ViT-B/16': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'),
+    'ViT-L/14': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'),
+    'ViT-L/14@336px': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'),
 }

 MODEL_SIZE = {
@@ -50,16 +51,34 @@
 }


-def _download(url: str, root: str, with_resume: bool = True):
-    os.makedirs(root, exist_ok=True)
+def md5file(filename: str):
+    hash_md5 = hashlib.md5()
+    with open(filename, 'rb') as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            hash_md5.update(chunk)
+
+    return hash_md5.hexdigest()
+
+
+def _download(
+    url: str,
+    target_folder: str,
+    md5sum: str = None,
+    with_resume: bool = True,
+    max_attempts: int = 3,
+) -> str:
+    os.makedirs(target_folder, exist_ok=True)
     filename = os.path.basename(url)

-    download_target = os.path.join(root, filename)
-    if os.path.isfile(download_target):
-        return download_target
+    download_target = os.path.join(target_folder, filename)

-    if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise FileExistsError(f'{download_target} exists and is not a regular file')
+    if os.path.exists(download_target):
+        if not os.path.isfile(download_target):
+            raise FileExistsError(f'{download_target} exists and is not a regular file')
+
+        actual_md5sum = md5file(download_target)
+        if (not md5sum) or actual_md5sum == md5sum:
+            return download_target

     from rich.progress import (
         DownloadColumn,
@@ -81,53 +100,58 @@ def _download(url: str, root: str, with_resume: bool = True):
     )

     with progress:

         task = progress.add_task('download', filename=url, start=False)

-        tmp_file_path = download_target + '.part'
-        resume_byte_pos = (
-            os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0
-        )
-
-        total_bytes = -1
-        try:
-            # resolve the 403 error by passing a valid user-agent
-            req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
-
-            total_bytes = int(
-                urllib.request.urlopen(req).info().get('Content-Length', -1)
-            )
+        for _ in range(max_attempts):
+            tmp_file_path = download_target + '.part'
+            resume_byte_pos = (
+                os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0
+            )

-            mode = 'ab' if (with_resume and resume_byte_pos) else 'wb'
-
-            with open(tmp_file_path, mode) as output:
-
-                progress.update(task, total=total_bytes)
-
-                progress.start_task(task)
-
-                if resume_byte_pos and with_resume:
-                    progress.update(task, advance=resume_byte_pos)
-                    req.headers['Range'] = f'bytes={resume_byte_pos}-'
-
-                with urllib.request.urlopen(req) as source:
-                    while True:
-                        buffer = source.read(8192)
-                        if not buffer:
-                            break
-
-                        output.write(buffer)
-                        progress.update(task, advance=len(buffer))
-        except Exception as ex:
-            raise ex
-        finally:
-            # rename the temp download file to the correct name if fully downloaded
-            if os.path.exists(tmp_file_path) and (
-                total_bytes == os.path.getsize(tmp_file_path)
-            ):
-                shutil.move(tmp_file_path, download_target)
+            try:
+                # resolve the 403 error by passing a valid user-agent
+                req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
+                total_bytes = int(
+                    urllib.request.urlopen(req).info().get('Content-Length', -1)
+                )
+                mode = 'ab' if (with_resume and resume_byte_pos) else 'wb'

+                with open(tmp_file_path, mode) as output:
+                    progress.update(task, total=total_bytes)
+                    progress.start_task(task)

+                    if resume_byte_pos and with_resume:
+                        progress.update(task, advance=resume_byte_pos)
+                        req.headers['Range'] = f'bytes={resume_byte_pos}-'

+                    with urllib.request.urlopen(req) as source:
+                        while True:
+                            buffer = source.read(8192)
+                            if not buffer:
+                                break

+                            output.write(buffer)
+                            progress.update(task, advance=len(buffer))

+                actual_md5 = md5file(tmp_file_path)
+                if (md5sum and actual_md5 == md5sum) or (not md5sum):
+                    shutil.move(tmp_file_path, download_target)
+                    return download_target
+                else:
+                    os.remove(tmp_file_path)
+                    raise RuntimeError(
+                        f'MD5 mismatch: expected {md5sum}, got {actual_md5}'
+                    )

+            except Exception as ex:
+                progress.console.print(
+                    f'Failed to download {url} with {ex!r} at the {_}th attempt'
+                )
+                progress.reset(task)

+    raise RuntimeError(
+        f'Failed to download {url} within retry limit {max_attempts}'
+    )


def _convert_image_to_rgb(image):
@@ -193,7 +217,7 @@ def load(
         Whether to load the optimized JIT model or more hackable non-JIT model (default).
     download_root: str
-        path to download the model files; by default, it uses '~/.cache/clip'
+        path to download the model files; by default, it uses '~/.cache/clip/'

     Returns
     -------
@@ -204,9 +228,11 @@
         A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
     """
     if name in _MODELS:
+        model_name, model_md5 = _MODELS[name]
         model_path = _download(
-            _S3_BUCKET + _MODELS[name],
-            download_root or os.path.expanduser('~/.cache/clip'),
+            url=_S3_BUCKET + model_name,
+            target_folder=download_root or os.path.expanduser('~/.cache/clip'),
+            md5sum=model_md5,
             with_resume=True,
         )
     elif os.path.isfile(name):
@@ -309,7 +335,7 @@ def patch_float(module):

 def tokenize(
     texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True
-) -> torch.LongTensor:
+) -> dict:
     """
     Returns the tokenized representation of given input string(s)
@@ -326,15 +352,18 @@ def tokenize(
     Returns
     -------
-    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+    A dict of tokenized representations of the input strings and their corresponding attention masks with both
+    shape = [batch size, context_length]
     """
     if isinstance(texts, str):
         texts = [texts]

     sot_token = _tokenizer.encoder['<|startoftext|>']
     eot_token = _tokenizer.encoder['<|endoftext|>']
     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
-    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+    input_ids = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    attention_mask = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
@@ -345,6 +374,7 @@
             raise RuntimeError(
                 f'Input {texts[i]} is too long for context length {context_length}'
             )
-        result[i, : len(tokens)] = torch.tensor(tokens)
+        input_ids[i, : len(tokens)] = torch.tensor(tokens)
+        attention_mask[i, : len(tokens)] = 1

-    return result
+    return {'input_ids': input_ids, 'attention_mask': attention_mask}
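And a sketch of the verified-download path, reusing the (filename, md5) pairs that `_MODELS` now carries. Every name below comes from the diff, though these are private helpers rather than public API:

```python
import os
from clip_server.model.clip import _MODELS, _S3_BUCKET, _download, md5file

model_name, model_md5 = _MODELS['ViT-B/32']
path = _download(
    url=_S3_BUCKET + model_name,
    target_folder=os.path.expanduser('~/.cache/clip'),
    md5sum=model_md5,  # a stale or corrupted file fails the check and is re-fetched
    with_resume=True,
    max_attempts=3,
)
assert md5file(path) == model_md5
```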
