From ee7da10d1f56a130e6f9a85d5fb3518b80e5df0d Mon Sep 17 00:00:00 2001 From: Ziniu Yu Date: Tue, 12 Jul 2022 11:21:24 +0800 Subject: [PATCH] feat: support custom onnx file and update model signatures (#761) * feat: allow custom onnx file * fix: path name * fix: validate model path * chore: improve error message * test: add custom path unit test * test: add test cases * test: add test cases * test: add test cases * fix: reindent * fix: change type to int32 * fix: modify text input * chore: format code * chore: update model links * fix: update links * fix: typo * fix: add attention mask for onnx * fix: trt text encode key * fix: fix trt shape * fix: trt convert * fix: trt convert * fix: tensorrt parse model * fix: add md5 verification * fix: add md5 verification * feat: add md5 validation * feat: add torch md5 * feat: add torch md5 * feat: add onnx md5 * fix: md5 validation * chore: clean up * fix: typo * fix: typo * fix: typo * fix: correct path * fix: trt path * test: add md5 test * test: add path test * fix: house keeping * fix: house keeping * fix: house keeping * fix: md5 test case * fix: modify visual signature * fix: modify visual signature * fix: improve download retry * fix: trt timeout 30 min * fix: modify download logic * docs: update trt * fix: validation * fix: polish download with md5 * fix: polish download with md5 * fix: stop with max retires * fix: use forloop * test: none regular file Co-authored-by: numb3r3 --- docs/conf.py | 4 +- docs/user-guides/server.md | 4 +- server/clip_server/executors/clip_onnx.py | 3 +- server/clip_server/executors/clip_torch.py | 4 +- server/clip_server/executors/helper.py | 21 ++- server/clip_server/model/clip.py | 162 ++++++++++++--------- server/clip_server/model/clip_onnx.py | 97 +++++++++--- server/clip_server/model/clip_trt.py | 141 +++++++++--------- server/clip_server/tensorrt-flow.yml | 1 + tests/conftest.py | 24 ++- tests/test_server.py | 84 ++++++++++- 11 files changed, 363 insertions(+), 182 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 07a590588..b183f8ac7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -234,7 +234,9 @@ def setup(app): ) app.add_config_value( name='server_address', - default=os.getenv('JINA_DOCSBOT_SERVER', 'https://jina-ai-clip-as-service.docsqa.jina.ai'), + default=os.getenv( + 'JINA_DOCSBOT_SERVER', 'https://jina-ai-clip-as-service.docsqa.jina.ai' + ), rebuild='', ) app.connect('builder-inited', configure_qa_bot_ui) diff --git a/docs/user-guides/server.md b/docs/user-guides/server.md index cb8e15b19..28e46c934 100644 --- a/docs/user-guides/server.md +++ b/docs/user-guides/server.md @@ -61,7 +61,7 @@ The procedure and UI of ONNX and TensorRT runtime would look the same as Pytorch ## Model support -Open AI has released 9 models so far. `ViT-B/32` is used as default model in all runtimes. Due to the limitation of some runtime, not every runtime supports all nine models. Please also note that different model give different size of output dimensions. This will affect your downstream applications. For example, switching the model from one to another make your embedding incomparable, which breaks the downstream applications. Below is a list of supported models of each runtime and its corresponding size. We include the disk usage (in delta) and the peak RAM and VRAM usage (in delta) when running on a single Nvidia TITAN RTX GPU (24GB VRAM) using a default `minibatch_size=32` in server and a default `batch_size=8` in client. +Open AI has released 9 models so far. `ViT-B/32` is used as default model in all runtimes. Due to the limitation of some runtime, not every runtime supports all nine models. Please also note that different model give different size of output dimensions. This will affect your downstream applications. For example, switching the model from one to another make your embedding incomparable, which breaks the downstream applications. Below is a list of supported models of each runtime and its corresponding size. We include the disk usage (in delta) and the peak RAM and VRAM usage (in delta) when running on a single Nvidia TITAN RTX GPU (24GB VRAM) using a default `minibatch_size=32` in server with PyTorch runtime and a default `batch_size=8` in client. | Model | PyTorch | ONNX | TensorRT | Output Dimension | Disk Usage (MB) | Peak RAM Usage (GB) | Peak VRAM Usage (GB) | |----------------|---------|------|----------|------------------|-----------------|---------------------|----------------------| @@ -72,7 +72,7 @@ Open AI has released 9 models so far. `ViT-B/32` is used as default model in all | RN50x64 | ✅ | ✅ | ❌ | 1024 | 1382 | 4.08 | 2.98 | | ViT-B/32 | ✅ | ✅ | ✅ | 512 | 351 | 3.20 | 1.40 | | ViT-B/16 | ✅ | ✅ | ✅ | 512 | 354 | 3.20 | 1.44 | -| ViT-L/14 | ✅ | ✅ | ✅ | 768 | 933 | 3.66 | 2.04 | +| ViT-L/14 | ✅ | ✅ | ❌ | 768 | 933 | 3.66 | 2.04 | | ViT-L/14-336px | ✅ | ✅ | ❌ | 768 | 934 | 3.74 | 2.23 | diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index deed22328..661f55908 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -23,6 +23,7 @@ def __init__( num_worker_preprocess: int = 4, minibatch_size: int = 32, traversal_paths: str = '@r', + model_path: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -33,7 +34,7 @@ def __init__( self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name]) self._pool = ThreadPool(processes=num_worker_preprocess) - self._model = CLIPOnnxModel(name) + self._model = CLIPOnnxModel(name, model_path) import torch diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index a4701004a..59173307e 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -108,7 +108,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): documentation='images encode time in seconds', ): minibatch.embeddings = ( - self._model.encode_image(batch_data) + self._model.encode_image(batch_data['pixel_values']) .cpu() .numpy() .astype(np.float32) @@ -126,7 +126,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): documentation='texts encode time in seconds', ): minibatch.embeddings = ( - self._model.encode_text(batch_data) + self._model.encode_text(batch_data['input_ids']) .cpu() .numpy() .astype(np.float32) diff --git a/server/clip_server/executors/helper.py b/server/clip_server/executors/helper.py index 4e1ddecb3..9ecb7238c 100644 --- a/server/clip_server/executors/helper.py +++ b/server/clip_server/executors/helper.py @@ -1,4 +1,4 @@ -from typing import Tuple, List, Callable, Any +from typing import Tuple, List, Callable, Any, Dict import torch import numpy as np from docarray import Document, DocumentArray @@ -20,7 +20,7 @@ def preproc_image( preprocess_fn: Callable, device: str = 'cpu', return_np: bool = False, -) -> Tuple['DocumentArray', List[Any]]: +) -> Tuple['DocumentArray', Dict]: tensors_batch = [] @@ -45,22 +45,27 @@ def preproc_image( else: tensors_batch = tensors_batch.to(device) - return da, tensors_batch + return da, {'pixel_values': tensors_batch} def preproc_text( da: 'DocumentArray', device: str = 'cpu', return_np: bool = False -) -> Tuple['DocumentArray', List[Any]]: +) -> Tuple['DocumentArray', Dict]: - tensors_batch = clip.tokenize(da.texts).detach() + inputs = clip.tokenize(da.texts) + inputs['input_ids'] = inputs['input_ids'].detach() if return_np: - tensors_batch = tensors_batch.cpu().numpy().astype(np.int64) + inputs['input_ids'] = inputs['input_ids'].cpu().numpy().astype(np.int32) + inputs['attention_mask'] = ( + inputs['attention_mask'].cpu().numpy().astype(np.int32) + ) else: - tensors_batch = tensors_batch.to(device) + inputs['input_ids'] = inputs['input_ids'].to(device) + inputs['attention_mask'] = inputs['attention_mask'].to(device) da[:, 'mime_type'] = 'text' - return da, tensors_batch + return da, inputs def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'): diff --git a/server/clip_server/model/clip.py b/server/clip_server/model/clip.py index 315003b99..9e2fae77e 100644 --- a/server/clip_server/model/clip.py +++ b/server/clip_server/model/clip.py @@ -2,6 +2,7 @@ import io import os +import hashlib import shutil import urllib import warnings @@ -26,15 +27,15 @@ _S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch/' _MODELS = { - 'RN50': 'RN50.pt', - 'RN101': 'RN101.pt', - 'RN50x4': 'RN50x4.pt', - 'RN50x16': 'RN50x16.pt', - 'RN50x64': 'RN50x64.pt', - 'ViT-B/32': 'ViT-B-32.pt', - 'ViT-B/16': 'ViT-B-16.pt', - 'ViT-L/14': 'ViT-L-14.pt', - 'ViT-L/14@336px': 'ViT-L-14-336px.pt', + 'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), + 'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), + 'RN50x4': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'), + 'RN50x16': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'), + 'RN50x64': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'), + 'ViT-B/32': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'), + 'ViT-B/16': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'), + 'ViT-L/14': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'), + 'ViT-L/14@336px': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'), } MODEL_SIZE = { @@ -50,16 +51,34 @@ } -def _download(url: str, root: str, with_resume: bool = True): - os.makedirs(root, exist_ok=True) +def md5file(filename: str): + hash_md5 = hashlib.md5() + with open(filename, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + + return hash_md5.hexdigest() + + +def _download( + url: str, + target_folder: str, + md5sum: str = None, + with_resume: bool = True, + max_attempts: int = 3, +) -> str: + os.makedirs(target_folder, exist_ok=True) filename = os.path.basename(url) - download_target = os.path.join(root, filename) - if os.path.isfile(download_target): - return download_target + download_target = os.path.join(target_folder, filename) - if os.path.exists(download_target) and not os.path.isfile(download_target): - raise FileExistsError(f'{download_target} exists and is not a regular file') + if os.path.exists(download_target): + if not os.path.isfile(download_target): + raise FileExistsError(f'{download_target} exists and is not a regular file') + + actual_md5sum = md5file(download_target) + if (not md5sum) or actual_md5sum == md5sum: + return download_target from rich.progress import ( DownloadColumn, @@ -81,53 +100,58 @@ def _download(url: str, root: str, with_resume: bool = True): ) with progress: - task = progress.add_task('download', filename=url, start=False) - tmp_file_path = download_target + '.part' - resume_byte_pos = ( - os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0 - ) - - total_bytes = -1 - try: - # resolve the 403 error by passing a valid user-agent - req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) - - total_bytes = int( - urllib.request.urlopen(req).info().get('Content-Length', -1) + for _ in range(max_attempts): + tmp_file_path = download_target + '.part' + resume_byte_pos = ( + os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0 ) - mode = 'ab' if (with_resume and resume_byte_pos) else 'wb' - - with open(tmp_file_path, mode) as output: - - progress.update(task, total=total_bytes) - - progress.start_task(task) - - if resume_byte_pos and with_resume: - progress.update(task, advance=resume_byte_pos) - req.headers['Range'] = f'bytes={resume_byte_pos}-' - - with urllib.request.urlopen(req) as source: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - progress.update(task, advance=len(buffer)) - except Exception as ex: - raise ex - finally: - # rename the temp download file to the correct name if fully downloaded - if os.path.exists(tmp_file_path) and ( - total_bytes == os.path.getsize(tmp_file_path) - ): - shutil.move(tmp_file_path, download_target) + try: + # resolve the 403 error by passing a valid user-agent + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + total_bytes = int( + urllib.request.urlopen(req).info().get('Content-Length', -1) + ) + mode = 'ab' if (with_resume and resume_byte_pos) else 'wb' + + with open(tmp_file_path, mode) as output: + progress.update(task, total=total_bytes) + progress.start_task(task) + + if resume_byte_pos and with_resume: + progress.update(task, advance=resume_byte_pos) + req.headers['Range'] = f'bytes={resume_byte_pos}-' + + with urllib.request.urlopen(req) as source: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + progress.update(task, advance=len(buffer)) + + actual_md5 = md5file(tmp_file_path) + if (md5sum and actual_md5 == md5sum) or (not md5sum): + shutil.move(tmp_file_path, download_target) + return download_target + else: + os.remove(tmp_file_path) + raise RuntimeError( + f'MD5 mismatch: expected {md5sum}, got {actual_md5}' + ) + + except Exception as ex: + progress.console.print( + f'Failed to download {url} with {ex!r} at the {_}th attempt' + ) + progress.reset(task) - return download_target + raise RuntimeError( + f'Failed to download {url} within retry limit {max_attempts}' + ) def _convert_image_to_rgb(image): @@ -193,7 +217,7 @@ def load( Whether to load the optimized JIT model or more hackable non-JIT model (default). download_root: str - path to download the model files; by default, it uses '~/.cache/clip' + path to download the model files; by default, it uses '~/.cache/clip/' Returns ------- @@ -204,9 +228,11 @@ def load( A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if name in _MODELS: + model_name, model_md5 = _MODELS[name] model_path = _download( - _S3_BUCKET + _MODELS[name], - download_root or os.path.expanduser('~/.cache/clip'), + url=_S3_BUCKET + model_name, + target_folder=download_root or os.path.expanduser('~/.cache/clip'), + md5sum=model_md5, with_resume=True, ) elif os.path.isfile(name): @@ -309,7 +335,7 @@ def patch_float(module): def tokenize( texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True -) -> torch.LongTensor: +) -> dict: """ Returns the tokenized representation of given input string(s) @@ -326,7 +352,8 @@ def tokenize( Returns ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + A dict of tokenized representations of the input strings and their corresponding attention masks with both + shape = [batch size, context_length] """ if isinstance(texts, str): texts = [texts] @@ -334,7 +361,9 @@ def tokenize( sot_token = _tokenizer.encoder['<|startoftext|>'] eot_token = _tokenizer.encoder['<|endoftext|>'] all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + input_ids = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + attention_mask = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: @@ -345,6 +374,7 @@ def tokenize( raise RuntimeError( f'Input {texts[i]} is too long for context length {context_length}' ) - result[i, : len(tokens)] = torch.tensor(tokens) + input_ids[i, : len(tokens)] = torch.tensor(tokens) + attention_mask[i, : len(tokens)] = 1 - return result + return {'input_ids': input_ids, 'attention_mask': attention_mask} diff --git a/server/clip_server/model/clip_onnx.py b/server/clip_server/model/clip_onnx.py index 9326bcaaa..1bbb8f57b 100644 --- a/server/clip_server/model/clip_onnx.py +++ b/server/clip_server/model/clip_onnx.py @@ -2,30 +2,85 @@ from clip_server.model.clip import _download, available_models -_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/' +_S3_BUCKET = ( + 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/' # Deprecated +) +_S3_BUCKET_V2 = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models-436c69702d61732d53657276696365/onnx/' _MODELS = { - 'RN50': ('RN50/textual.onnx', 'RN50/visual.onnx'), - 'RN101': ('RN101/textual.onnx', 'RN101/visual.onnx'), - 'RN50x4': ('RN50x4/textual.onnx', 'RN50x4/visual.onnx'), - 'RN50x16': ('RN50x16/textual.onnx', 'RN50x16/visual.onnx'), - 'RN50x64': ('RN50x64/textual.onnx', 'RN50x64/visual.onnx'), - 'ViT-B/32': ('ViT-B-32/textual.onnx', 'ViT-B-32/visual.onnx'), - 'ViT-B/16': ('ViT-B-16/textual.onnx', 'ViT-B-16/visual.onnx'), - 'ViT-L/14': ('ViT-L-14/textual.onnx', 'ViT-L-14/visual.onnx'), - 'ViT-L/14@336px': ('ViT-L-14@336px/textual.onnx', 'ViT-L-14@336px/visual.onnx'), + 'RN50': ( + ('RN50/textual.onnx', '722418bfe47a1f5c79d1f44884bb3103'), + ('RN50/visual.onnx', '5761475db01c3abb68a5a805662dcd10'), + ), + 'RN101': ( + ('RN101/textual.onnx', '2d9efb7d184c0d68a369024cedfa97af'), + ('RN101/visual.onnx', '0297ebc773af312faab54f8b5a622d71'), + ), + 'RN50x4': ( + ('RN50x4/textual.onnx', 'd9d63d3fe35fb14d4affaa2c4e284005'), + ('RN50x4/visual.onnx', '16afe1e35b85ad862e8bbdb12265c9cb'), + ), + 'RN50x16': ( + ('RN50x16/textual.onnx', '1525785494ff5307cadc6bfa56db6274'), + ('RN50x16/visual.onnx', '2a293d9c3582f8abe29c9999e47d1091'), + ), + 'RN50x64': ( + ('RN50x64/textual.onnx', '3ae8ade74578eb7a77506c11bfbfaf2c'), + ('RN50x64/visual.onnx', '1341f10b50b3aca6d2d5d13982cabcfc'), + ), + 'ViT-B/32': ( + ('ViT-B-32/textual.onnx', 'bd6d7871e8bb95f3cc83aff3398d7390'), + ('ViT-B-32/visual.onnx', '88c6f38e522269d6c04a85df18e6370c'), + ), + 'ViT-B/16': ( + ('ViT-B-16/textual.onnx', '6f0976629a446f95c0c8767658f12ebe'), + ('ViT-B-16/visual.onnx', 'd5c03bfeef1abbd9bede54a8f6e1eaad'), + ), + 'ViT-L/14': ( + ('ViT-L-14/textual.onnx', '325380b31af4837c2e0d9aba2fad8e1b'), + ('ViT-L-14/visual.onnx', '53f5b319d3dc5d42572adea884e31056'), + ), + 'ViT-L/14@336px': ( + ('ViT-L-14@336px/textual.onnx', '78fab479f136403eed0db46f3e9e7ed2'), + ('ViT-L-14@336px/visual.onnx', 'f3b1f5d55ca08d43d749e11f7e4ba27e'), + ), } class CLIPOnnxModel: - def __init__(self, name: str = None): + def __init__(self, name: str = None, model_path: str = None): if name in _MODELS: - cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}') - self._textual_path = _download( - _S3_BUCKET + _MODELS[name][0], cache_dir, with_resume=True - ) - self._visual_path = _download( - _S3_BUCKET + _MODELS[name][1], cache_dir, with_resume=True - ) + if not model_path: + cache_dir = os.path.expanduser( + f'~/.cache/clip/{name.replace("/", "-")}' + ) + textual_model_name, textual_model_md5 = _MODELS[name][0] + self._textual_path = _download( + url=_S3_BUCKET_V2 + textual_model_name, + target_folder=cache_dir, + md5sum=textual_model_md5, + with_resume=True, + ) + visual_model_name, visual_model_md5 = _MODELS[name][1] + self._visual_path = _download( + url=_S3_BUCKET_V2 + visual_model_name, + target_folder=cache_dir, + md5sum=visual_model_md5, + with_resume=True, + ) + else: + if os.path.isdir(model_path): + self._textual_path = os.path.join(model_path, 'textual.onnx') + self._visual_path = os.path.join(model_path, 'visual.onnx') + if not os.path.isfile(self._textual_path) or not os.path.isfile( + self._visual_path + ): + raise RuntimeError( + f'The given model path {model_path} does not contain `textual.onnx` and `visual.onnx`' + ) + else: + raise RuntimeError( + f'The given model path {model_path} is not a valid directory' + ) else: raise RuntimeError( f'Model {name} not found; available models = {available_models()}' @@ -44,11 +99,9 @@ def start_sessions( self._textual_session.disable_fallback() def encode_image(self, onnx_image): - onnx_input_image = {self._visual_session.get_inputs()[0].name: onnx_image} - (visual_output,) = self._visual_session.run(None, onnx_input_image) + (visual_output,) = self._visual_session.run(None, onnx_image) return visual_output def encode_text(self, onnx_text): - onnx_input_text = {self._textual_session.get_inputs()[0].name: onnx_text} - (textual_output,) = self._textual_session.run(None, onnx_input_text) + (textual_output,) = self._textual_session.run(None, onnx_text) return textual_output diff --git a/server/clip_server/model/clip_trt.py b/server/clip_server/model/clip_trt.py index c1e945a2a..b4803281e 100644 --- a/server/clip_server/model/clip_trt.py +++ b/server/clip_server/model/clip_trt.py @@ -13,19 +13,20 @@ "https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html" ) -from clip_server.model.clip import _download, MODEL_SIZE - -_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/tensorrt/' -_MODELS = { - 'RN50': ('RN50/textual.trt', 'RN50/visual.trt'), - 'RN101': ('RN101/textual.trt', 'RN101/visual.trt'), - 'RN50x4': ('RN50x4/textual.trt', 'RN50x4/visual.trt'), - # 'RN50x16': ('RN50x16/textual.trt', 'RN50x16/visual.trt'), - # 'RN50x64': ('RN50x64/textual.trt', 'RN50x64/visual.trt'), - 'ViT-B/32': ('ViT-B-32/textual.trt', 'ViT-B-32/visual.trt'), - 'ViT-B/16': ('ViT-B-16/textual.trt', 'ViT-B-16/visual.trt'), - 'ViT-L/14': ('ViT-L-14/textual.trt', 'ViT-L-14/visual.trt'), -} +from clip_server.model.clip import MODEL_SIZE +from clip_server.model.clip_onnx import _MODELS as ONNX_MODELS + +_MODELS = [ + 'RN50', + 'RN101', + 'RN50x4', + # 'RN50x16', + # 'RN50x64', + 'ViT-B/32', + 'ViT-B/16', + # 'ViT-L/14', + # 'ViT-L/14@336px', +] class CLIPTensorRTModel: @@ -34,77 +35,77 @@ def __init__( name: str = None, ): if name in _MODELS: + self._name = name cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}') - self._textual_path = _download(_S3_BUCKET + _MODELS[name][0], cache_dir) - self._visual_path = _download(_S3_BUCKET + _MODELS[name][1], cache_dir) + + self._textual_path = os.path.join( + cache_dir, + f'textual.{ONNX_MODELS[name][0][1]}.trt', + ) + self._visual_path = os.path.join( + cache_dir, + f'visual.{ONNX_MODELS[name][1][1]}.trt', + ) + + if not os.path.exists(self._textual_path) or not os.path.exists( + self._visual_path + ): + from clip_server.model.clip_onnx import CLIPOnnxModel + + trt_logger: Logger = trt.Logger(trt.Logger.ERROR) + runtime: Runtime = trt.Runtime(trt_logger) + onnx_model = CLIPOnnxModel(self._name) + + visual_engine = build_engine( + runtime=runtime, + onnx_file_path=onnx_model._visual_path, + logger=trt_logger, + min_shape=(1, 3, MODEL_SIZE[self._name], MODEL_SIZE[self._name]), + optimal_shape=( + 768, + 3, + MODEL_SIZE[self._name], + MODEL_SIZE[self._name], + ), + max_shape=( + 1024, + 3, + MODEL_SIZE[self._name], + MODEL_SIZE[self._name], + ), + workspace_size=10000 * 1024 * 1024, + fp16=False, + int8=False, + ) + save_engine(visual_engine, self._visual_path) + + text_engine = build_engine( + runtime=runtime, + onnx_file_path=onnx_model._textual_path, + logger=trt_logger, + min_shape=(1, 77), + optimal_shape=(768, 77), + max_shape=(1024, 77), + workspace_size=10000 * 1024 * 1024, + fp16=False, + int8=False, + ) + save_engine(text_engine, self._textual_path) else: raise RuntimeError( f'Model {name} not found or not supports Nvidia TensorRT backend; available models = {list(_MODELS.keys())}' ) - self._name = name def start_engines(self): - import torch - trt_logger: Logger = trt.Logger(trt.Logger.ERROR) runtime: Runtime = trt.Runtime(trt_logger) - compute_capacity = torch.cuda.get_device_capability() - - if compute_capacity != (8, 6): - print( - f'The engine plan file is generated on an incompatible device, expecting compute {compute_capacity} ' - 'got compute 8.6, will rebuild the TensorRT engine.' - ) - from clip_server.model.clip_onnx import CLIPOnnxModel - - onnx_model = CLIPOnnxModel(self._name) - - visual_engine = build_engine( - runtime=runtime, - onnx_file_path=onnx_model._visual_path, - logger=trt_logger, - min_shape=(1, 3, MODEL_SIZE[self._name], MODEL_SIZE[self._name]), - optimal_shape=( - 768, - 3, - MODEL_SIZE[self._name], - MODEL_SIZE[self._name], - ), - max_shape=( - 1024, - 3, - MODEL_SIZE[self._name], - MODEL_SIZE[self._name], - ), - workspace_size=10000 * 1024 * 1024, - fp16=False, - int8=False, - ) - - save_engine(visual_engine, self._visual_path) - - text_engine = build_engine( - runtime=runtime, - onnx_file_path=onnx_model._textual_path, - logger=trt_logger, - min_shape=(1, 77), - optimal_shape=(768, 77), - max_shape=(1024, 77), - workspace_size=10000 * 1024 * 1024, - fp16=False, - int8=False, - ) - save_engine(text_engine, self._textual_path) - self._textual_engine = load_engine(runtime, self._textual_path) self._visual_engine = load_engine(runtime, self._visual_path) def encode_image(self, onnx_image): - (visual_output,) = self._visual_engine({'input': onnx_image}) - + (visual_output,) = self._visual_engine(onnx_image) return visual_output def encode_text(self, onnx_text): - (textual_output,) = self._textual_engine({'input': onnx_text}) - + (textual_output,) = self._textual_engine(onnx_text) return textual_output diff --git a/server/clip_server/tensorrt-flow.yml b/server/clip_server/tensorrt-flow.yml index cbe765091..6934c9993 100644 --- a/server/clip_server/tensorrt-flow.yml +++ b/server/clip_server/tensorrt-flow.yml @@ -9,4 +9,5 @@ executors: metas: py_modules: - executors/clip_tensorrt.py + timeout_ready: 3000000 replicas: 1 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index f5972b0dd..cc7feea03 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,16 +16,26 @@ def random_port(): return random_port -@pytest.fixture(scope='session', params=['onnx', 'torch', 'hg']) +@pytest.fixture(scope='session', params=['onnx', 'torch', 'hg', 'onnx_custom']) def make_flow(port_generator, request): - if request.param == 'onnx': - from clip_server.executors.clip_onnx import CLIPEncoder - elif request.param == 'torch': - from clip_server.executors.clip_torch import CLIPEncoder + if request.param != 'onnx_custom': + if request.param == 'onnx': + from clip_server.executors.clip_onnx import CLIPEncoder + elif request.param == 'torch': + from clip_server.executors.clip_torch import CLIPEncoder + else: + from clip_server.executors.clip_hg import CLIPEncoder + + f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder) else: - from clip_server.executors.clip_hg import CLIPEncoder + import os + from clip_server.executors.clip_onnx import CLIPEncoder - f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder) + f = Flow(port=port_generator()).add( + name=request.param, + uses=CLIPEncoder, + uses_with={'model_path': os.path.expanduser('~/.cache/clip/ViT-B-32')}, + ) with f: yield f diff --git a/tests/test_server.py b/tests/test_server.py index c1d99748d..a9b476149 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,12 +3,17 @@ import pytest from clip_server.model.clip import _transform_ndarray, _transform_blob, _download from docarray import Document +from jina import Flow import numpy as np def test_server_download(tmpdir): - _download('https://docarray.jina.ai/_static/favicon.png', tmpdir, with_resume=False) - + _download( + url='https://docarray.jina.ai/_static/favicon.png', + target_folder=tmpdir, + md5sum='a084999188f4290e2654aec43207ff2e', + with_resume=False, + ) target_path = os.path.join(tmpdir, 'favicon.png') file_size = os.path.getsize(target_path) assert file_size > 0 @@ -20,11 +25,84 @@ def test_server_download(tmpdir): os.remove(target_path) - _download('https://docarray.jina.ai/_static/favicon.png', tmpdir, with_resume=True) + _download( + url='https://docarray.jina.ai/_static/favicon.png', + target_folder=tmpdir, + md5sum='a084999188f4290e2654aec43207ff2e', + with_resume=True, + ) assert os.path.getsize(target_path) == file_size assert not os.path.exists(part_path) +@pytest.mark.parametrize('md5', ['ABC', None, 'a084999188f4290e2654aec43207ff2e']) +def test_server_download_md5(tmpdir, md5): + if md5 != 'ABC': + _download( + url='https://docarray.jina.ai/_static/favicon.png', + target_folder=tmpdir, + md5sum=md5, + with_resume=False, + ) + else: + with pytest.raises(Exception): + _download( + url='https://docarray.jina.ai/_static/favicon.png', + target_folder=tmpdir, + md5sum=md5, + with_resume=False, + ) + + +def test_server_download_not_regular_file(tmpdir): + with pytest.raises(Exception): + _download( + url='https://docarray.jina.ai/_static/favicon.png', + target_folder=tmpdir, + md5sum='', + with_resume=False, + ) + _download( + url='https://docarray.jina.ai/_static/', + target_folder=tmpdir, + md5sum='', + with_resume=False, + ) + + +def test_make_onnx_flow_custom_path_wrong_name(port_generator): + from clip_server.executors.clip_onnx import CLIPEncoder + + f = Flow(port=port_generator()).add( + name='onnx', + uses=CLIPEncoder, + uses_with={ + 'name': 'ABC', + 'model_path': os.path.expanduser('~/.cache/clip/ViT-B-32'), + }, + ) + with pytest.raises(Exception) as info: + with f: + f.post('/', Document(text='Hello world')) + + +@pytest.mark.parametrize('path', ['ABC', os.path.expanduser('~/.cache/')]) +def test_make_onnx_flow_custom_path_wrong_path(port_generator, path): + from clip_server.executors.clip_onnx import CLIPEncoder + + f = Flow(port=port_generator()).add( + name='onnx', + uses=CLIPEncoder, + uses_with={ + 'name': 'ViT-B/32', + 'model_path': path, + }, + ) + with pytest.raises(Exception) as info: + with f: + f.post('/', Document(text='Hello world')) + + @pytest.mark.parametrize( 'image_uri', [