diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ed6cafe82..31305b570 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -127,17 +127,65 @@ jobs:
with:
files: "coverage.xml"
- name: Upload coverage from test to Codecov
- uses: codecov/codecov-action@v2
+ uses: codecov/codecov-action@v3
if: steps.check_files.outputs.files_exists == 'true' && ${{ matrix.python-version }} == '3.7'
with:
file: coverage.xml
+ name: ${{ matrix.test-path }}-codecov
+ flags: ${{ steps.test.outputs.codecov_flag }}
+ fail_ci_if_error: false
+ token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
+
+ gpu-test:
+ needs: prep-testbed
+ runs-on: [self-hosted, x64, gpu, linux]
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: [ 3.7 ]
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ # For coverage builds fetch the whole history
+ fetch-depth: 0
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+      - name: Prepare environment
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install wheel pytest pytest-cov nvidia-pyindex
+ pip install -e "client/[test]"
+ pip install -e "server/[tensorrt]"
+ - name: Test
+ id: test
+ run: |
+ pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
+ -v -s -m "gpu" ./tests/test_tensorrt.py
+ echo "::set-output name=codecov_flag::cas"
+ timeout-minutes: 30
+ env:
+          # use spawn to avoid the "Cannot re-initialize CUDA in forked subprocess" torch error
+ JINA_MP_START_METHOD: spawn
+ - name: Check codecov file
+ id: check_files
+ uses: andstor/file-existence-action@v1
+ with:
+ files: "coverage.xml"
+ - name: Upload coverage from test to Codecov
+ uses: codecov/codecov-action@v3
+ if: steps.check_files.outputs.files_exists == 'true' && ${{ matrix.python-version }} == '3.7'
+ with:
+ file: coverage.xml
+ name: gpu-related-codecov
flags: ${{ steps.test.outputs.codecov_flag }}
fail_ci_if_error: false
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
# just for blocking the merge until all parallel core-test are successful
success-all-test:
- needs: core-test
+ needs: [core-test, gpu-test]
if: always()
runs-on: ubuntu-latest
steps:
diff --git a/README.md b/README.md
index 2af71c73b..63c885fac 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@
CLIP-as-service is a low-latency high-scalability service for embedding images and text. It can be easily integrated as a microservice into neural search solutions.
-⚡ **Fast**: Serve CLIP models with ONNX runtime and PyTorch JIT with 800QPS[*]. Non-blocking duplex streaming on requests and responses, designed for large data and long-running tasks.
+⚡ **Fast**: Serve CLIP models with TensorRT, ONNX runtime and PyTorch JIT with 800QPS[*]. Non-blocking duplex streaming on requests and responses, designed for large data and long-running tasks.
🫐 **Elastic**: Horizontally scale up and down multiple CLIP models on single GPU, with automatic load balancing.
@@ -58,6 +58,16 @@ To run CLIP model via ONNX (default is via PyTorch):
pip install "clip-server[onnx]"
```
+To run CLIP model via TensorRT:
+
+```bash
+# First install nvidia-pyindex, which configures your pip installation
+# to fetch additional Python modules from the NVIDIA NGC™ PyPI repo.
+pip install nvidia-pyindex
+
+pip install "clip-server[tensorrt]"
+```
+
### Install client
```bash
diff --git a/docs/user-guides/server.md b/docs/user-guides/server.md
index d7c113035..e34760fca 100644
--- a/docs/user-guides/server.md
+++ b/docs/user-guides/server.md
@@ -2,7 +2,7 @@
CLIP-as-service is designed in a client-server architecture. A server is a long-running program that receives raw sentences and images from clients, and returns CLIP embeddings to the client. Additionally, `clip_server` is optimized for speed, low memory footprint and scalability.
- Horizontal scaling: adding more replicas easily with one argument.
-- Vertical scaling: using PyTorch JIT or ONNX runtime to speedup single GPU inference.
+- Vertical scaling: using PyTorch JIT, ONNX or TensorRT runtime to speed up single-GPU inference.
- Supporting gRPC, HTTP, Websocket protocols with their TLS counterparts, w/o compressions.
This chapter introduces the API of the client.
@@ -34,12 +34,26 @@ To use ONNX runtime for CLIP, you can run:
```bash
pip install "clip_server[onnx]"
-python -m clip_server onnx_flow.yml
+python -m clip_server onnx-flow.yml
```
-One may wonder where is this `onnx_flow.yml` come from. Must be a typo? Believe me, just run it. It should work. I will explain this YAML file in the next section.
+We also support the TensorRT runtime for CLIP. To use it, run:
+
+```bash
+# First install nvidia-pyindex, which configures your pip installation
+# to fetch additional Python modules from the NVIDIA NGC™ PyPI repo.
+pip install nvidia-pyindex
+
+pip install "clip_server[tensorrt]"
+
+python -m clip_server tensorrt-flow.yml
+```
+
+One may wonder where this `onnx-flow.yml` (or `tensorrt-flow.yml`) comes from. Must be a typo? Believe me, just run it. It should work. I will explain this YAML file in the next section.
-The procedure and UI of ONNX runtime would look the same as Pytorch runtime.
+
+
+The procedure and UI of the ONNX and TensorRT runtimes look the same as the PyTorch runtime.
@@ -47,9 +61,10 @@ The procedure and UI of ONNX runtime would look the same as Pytorch runtime.
You may notice that there is a YAML file in our last ONNX example. All configurations are stored in this file. In fact, `python -m clip_server` does **not support** any other argument besides a YAML file. So it is the only source of the truth of your configs.
-And to answer your doubt, `clip_server` has two built-in YAML configs as a part of the package resources: one for PyTorch backend, one for ONNX backend. When you do `python -m clip_server` it loads the Pytorch config, and when you do `python -m clip_server onnx-flow.yml` it loads the ONNX config.
+And to answer your doubt, `clip_server` has three built-in YAML configs as part of the package resources: one each for the PyTorch, ONNX and TensorRT backends. When you do `python -m clip_server` it loads the PyTorch config, and when you do `python -m clip_server onnx-flow.yml` it loads the ONNX config.
+In the same way, when you do `python -m clip_server tensorrt-flow.yml` it loads the TensorRT config.
-Let's look at these two built-in YAML configs:
+Let's look at these three built-in YAML configs:
````{tab} torch-flow.yml
@@ -85,6 +100,24 @@ executors:
```
````
+
+````{tab} tensorrt-flow.yml
+
+```yaml
+jtype: Flow
+version: '1'
+with:
+ port: 51000
+executors:
+ - name: clip_r
+ uses:
+ jtype: CLIPEncoder
+ metas:
+ py_modules:
+ - executors/clip_trt.py
+```
+````
+
Basically, each YAML file defines a [Jina Flow](https://docs.jina.ai/fundamentals/flow/). The complete Jina Flow YAML syntax [can be found here](https://docs.jina.ai/fundamentals/flow/flow-yaml/#configure-flow-meta-information). General parameters of the Flow and Executor can be used here as well. But now we only highlight the most important parameters.
Looking at the YAML file again, we can put it into three subsections as below:
@@ -162,7 +195,7 @@ executors:
### CLIP model config
-For PyTorch & ONNX backend, you can set the following parameters via `with`:
+For all backends, you can set the following parameters via `with`:
| Parameter | Description |
|-----------|--------------------------------------------------------------------------------------------------------------------------------|
diff --git a/scripts/benchmark.py b/scripts/benchmark.py
index 4598edc8d..ffd3595eb 100644
--- a/scripts/benchmark.py
+++ b/scripts/benchmark.py
@@ -16,6 +16,9 @@ def warn(*args, **kwargs):
warnings.warn = warn
+np.random.seed(123)
+
+
class BenchmarkClient(threading.Thread):
def __init__(
self,
diff --git a/scripts/get-all-test-paths.sh b/scripts/get-all-test-paths.sh
index d58fb8552..e3eafa64b 100755
--- a/scripts/get-all-test-paths.sh
+++ b/scripts/get-all-test-paths.sh
@@ -6,7 +6,7 @@ BATCH_SIZE=3
#declare -a array1=( "tests/unit/test_*.py" )
#declare -a array2=( $(ls -d tests/unit/*/ | grep -v '__pycache__' | grep -v 'array') )
#declare -a array3=( "tests/unit/array/*.py" )
-declare -a mixins=( $(find tests -name "test_*.py") )
+declare -a mixins=( $(find tests -name "test_*.py" | grep -v 'test_tensorrt.py') )
declare -a array4=( "$(echo "${mixins[@]}" | xargs -n$BATCH_SIZE)" )
# array5 is currently empty because in the array/ directory, mixins is the only directory
# but add the following in case new directories are created in array/
diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py
index ab9984cc2..c7861aa14 100644
--- a/server/clip_server/executors/clip_onnx.py
+++ b/server/clip_server/executors/clip_onnx.py
@@ -1,26 +1,15 @@
import os
import warnings
-from multiprocessing.pool import ThreadPool, Pool
-from typing import List, Tuple, Optional
-import numpy as np
+from functools import partial
+from multiprocessing.pool import ThreadPool
+from typing import Optional
import onnxruntime as ort
from jina import Executor, requests, DocumentArray
from clip_server.model import clip
from clip_server.model.clip_onnx import CLIPOnnxModel
-
-_SIZE = {
- 'RN50': 224,
- 'RN101': 224,
- 'RN50x4': 288,
- 'RN50x16': 384,
- 'RN50x64': 448,
- 'ViT-B/32': 224,
- 'ViT-B/16': 224,
- 'ViT-L/14': 224,
- 'ViT-L/14@336px': 336,
-}
+from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text
class CLIPEncoder(Executor):
@@ -34,8 +23,7 @@ def __init__(
):
super().__init__(**kwargs)
- self._preprocess_blob = clip._transform_blob(_SIZE[name])
- self._preprocess_tensor = clip._transform_ndarray(_SIZE[name])
+ self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
self._pool = ThreadPool(processes=num_worker_preprocess)
self._minibatch_size = minibatch_size
@@ -86,51 +74,30 @@ def __init__(
self._model.start_sessions(sess_options=sess_options, providers=providers)
- def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray':
- for d in da:
- if d.tensor is not None:
- d.tensor = self._preprocess_tensor(d.tensor)
- else:
- if not d.blob and d.uri:
- # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri
- d.load_uri_to_blob()
- d.tensor = self._preprocess_blob(d.blob)
- da.tensors = da.tensors.detach().cpu().numpy().astype(np.float32)
- return da
-
- def _preproc_text(self, da: 'DocumentArray') -> Tuple['DocumentArray', List[str]]:
- texts = da.texts
- da.tensors = clip.tokenize(texts).detach().cpu().numpy().astype(np.int64)
- da[:, 'mime_type'] = 'text'
- return da, texts
-
@requests
async def encode(self, docs: 'DocumentArray', **kwargs):
_img_da = DocumentArray()
_txt_da = DocumentArray()
for d in docs:
- if d.text:
- _txt_da.append(d)
- elif (d.blob is not None) or (d.tensor is not None):
- _img_da.append(d)
- elif d.uri:
- _img_da.append(d)
- else:
- warnings.warn(
- f'The content of document {d.id} is empty, cannot be processed'
- )
+ split_img_txt_da(d, _img_da, _txt_da)
# for image
if _img_da:
for minibatch in _img_da.map_batch(
- self._preproc_image, batch_size=self._minibatch_size, pool=self._pool
+ partial(
+ preproc_image, preprocess_fn=self._preprocess_tensor, return_np=True
+ ),
+ batch_size=self._minibatch_size,
+ pool=self._pool,
):
minibatch.embeddings = self._model.encode_image(minibatch.tensors)
# for text
if _txt_da:
for minibatch, _texts in _txt_da.map_batch(
- self._preproc_text, batch_size=self._minibatch_size, pool=self._pool
+ partial(preproc_text, return_np=True),
+ batch_size=self._minibatch_size,
+ pool=self._pool,
):
minibatch.embeddings = self._model.encode_text(minibatch.tensors)
minibatch.texts = _texts
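The ONNX executor above now delegates preprocessing to the shared helpers via `functools.partial` and `DocumentArray.map_batch`. A minimal sketch of that pattern in isolation (batch size and pool size are illustrative; `preproc_text` is the helper introduced in `server/clip_server/executors/helper.py` later in this diff):

```python
# Hedged sketch of the refactored minibatching pattern used by the executors.
from functools import partial
from multiprocessing.pool import ThreadPool

from docarray import Document, DocumentArray

from clip_server.executors.helper import preproc_text

txt_da = DocumentArray([Document(text='hello'), Document(text='world')])
pool = ThreadPool(processes=2)

# Each batch is run through preproc_text, which returns (DocumentArray, texts).
for minibatch, texts in txt_da.map_batch(
    partial(preproc_text, return_np=True),
    batch_size=1,
    pool=pool,
):
    print(texts, minibatch.tensors.shape)  # token ids per batch, e.g. (1, 77)
```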
diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py
index b2c2569c3..5f77f8e70 100644
--- a/server/clip_server/executors/clip_torch.py
+++ b/server/clip_server/executors/clip_torch.py
@@ -1,11 +1,15 @@
import os
import warnings
+from functools import partial
+
from multiprocessing.pool import ThreadPool
from typing import Optional, List, Tuple, Dict
import numpy as np
import torch
from clip_server.model import clip
+from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text
+
from jina import Executor, requests, DocumentArray
@@ -46,39 +50,12 @@ def __init__(
torch.set_num_interop_threads(1)
self._minibatch_size = minibatch_size
- self._model, self._preprocess_blob, self._preprocess_tensor = clip.load(
+ self._model, self._preprocess_tensor = clip.load(
name, device=self._device, jit=jit
)
self._pool = ThreadPool(processes=num_worker_preprocess)
- def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray':
- for d in da:
- if d.tensor is not None:
- d.tensor = self._preprocess_tensor(d.tensor)
- else:
- if not d.blob and d.uri:
- # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri
- d.load_uri_to_blob()
- d.tensor = self._preprocess_blob(d.blob)
- da.tensors = da.tensors.to(self._device)
- return da
-
- def _preproc_text(self, da: 'DocumentArray') -> Tuple['DocumentArray', List[str]]:
- texts = da.texts
- da.tensors = clip.tokenize(texts).to(self._device)
- da[:, 'mime_type'] = 'text'
- return da, texts
-
- @staticmethod
- def _split_img_txt_da(d, _img_da, _txt_da):
- if d.text:
- _txt_da.append(d)
- elif (d.blob is not None) or (d.tensor is not None):
- _img_da.append(d)
- elif d.uri:
- _img_da.append(d)
-
@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
import torch
@@ -89,10 +66,10 @@ async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
for d in docs:
_img_da = DocumentArray()
_txt_da = DocumentArray()
- self._split_img_txt_da(d, _img_da, _txt_da)
+ split_img_txt_da(d, _img_da, _txt_da)
for c in _get(d):
- self._split_img_txt_da(c, _img_da, _txt_da)
+ split_img_txt_da(c, _img_da, _txt_da)
if len(_img_da) != 1 and len(_txt_da) != 1:
raise ValueError(
@@ -156,13 +133,18 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
_img_da = DocumentArray()
_txt_da = DocumentArray()
for d in docs:
- self._split_img_txt_da(d, _img_da, _txt_da)
+ split_img_txt_da(d, _img_da, _txt_da)
with torch.inference_mode():
# for image
if _img_da:
for minibatch in _img_da.map_batch(
- self._preproc_image,
+ partial(
+ preproc_image,
+ preprocess_fn=self._preprocess_tensor,
+ device=self._device,
+ return_np=False,
+ ),
batch_size=self._minibatch_size,
pool=self._pool,
):
@@ -176,7 +158,7 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
# for text
if _txt_da:
for minibatch, _texts in _txt_da.map_batch(
- self._preproc_text,
+ partial(preproc_text, device=self._device, return_np=False),
batch_size=self._minibatch_size,
pool=self._pool,
):
diff --git a/server/clip_server/executors/clip_trt.py b/server/clip_server/executors/clip_trt.py
new file mode 100644
index 000000000..866b044db
--- /dev/null
+++ b/server/clip_server/executors/clip_trt.py
@@ -0,0 +1,89 @@
+from multiprocessing.pool import ThreadPool
+from functools import partial
+import numpy as np
+from jina import Executor, requests, DocumentArray
+from jina.logging.logger import JinaLogger
+
+from clip_server.model import clip
+from clip_server.model.clip_trt import CLIPTensorRTModel
+from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text
+
+
+class CLIPEncoder(Executor):
+ def __init__(
+ self,
+ name: str = 'ViT-B/32',
+ device: str = 'cuda',
+ num_worker_preprocess: int = 4,
+ minibatch_size: int = 64,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.logger = JinaLogger(self.__class__.__name__)
+
+ self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
+ self._pool = ThreadPool(processes=num_worker_preprocess)
+
+ self._minibatch_size = minibatch_size
+ self._device = device
+
+ import torch
+
+ assert self._device.startswith('cuda'), (
+            f'cannot perform inference on {self._device}'
+            f' with the Nvidia TensorRT backend'
+ )
+
+ assert (
+ torch.cuda.is_available()
+ ), "CUDA/GPU is not available on Pytorch. Please check your CUDA installation"
+
+ self._model = CLIPTensorRTModel(name)
+
+ self._model.start_engines()
+
+ @requests
+ async def encode(self, docs: 'DocumentArray', **kwargs):
+ _img_da = DocumentArray()
+ _txt_da = DocumentArray()
+ for d in docs:
+ split_img_txt_da(d, _img_da, _txt_da)
+
+ # for image
+ if _img_da:
+ for minibatch in _img_da.map_batch(
+ partial(
+ preproc_image,
+ preprocess_fn=self._preprocess_tensor,
+ device=self._device,
+ return_np=False,
+ ),
+ batch_size=self._minibatch_size,
+ pool=self._pool,
+ ):
+ minibatch.embeddings = (
+ self._model.encode_image(minibatch.tensors)
+ .detach()
+ .cpu()
+ .numpy()
+ .astype(np.float32)
+ )
+
+ # for text
+ if _txt_da:
+ for minibatch, _texts in _txt_da.map_batch(
+ partial(preproc_text, device=self._device, return_np=False),
+ batch_size=self._minibatch_size,
+ pool=self._pool,
+ ):
+ minibatch.embeddings = (
+ self._model.encode_text(minibatch.tensors)
+ .detach()
+ .cpu()
+ .numpy()
+ .astype(np.float32)
+ )
+ minibatch.texts = _texts
+
+ # drop tensors
+ docs.tensors = None
diff --git a/server/clip_server/executors/helper.py b/server/clip_server/executors/helper.py
new file mode 100644
index 000000000..226e9e144
--- /dev/null
+++ b/server/clip_server/executors/helper.py
@@ -0,0 +1,52 @@
+from typing import Tuple, List, Callable, TYPE_CHECKING
+import numpy as np
+from clip_server.model import clip
+
+if TYPE_CHECKING:
+ from docarray import Document, DocumentArray
+
+
+def preproc_image(
+ da: 'DocumentArray',
+ preprocess_fn: Callable,
+ device: str = 'cpu',
+ return_np: bool = False,
+) -> 'DocumentArray':
+ for d in da:
+ if d.blob:
+ d.convert_blob_to_image_tensor()
+ elif d.tensor is None and d.uri:
+            # in case the user uses the HTTP protocol and sends data via curl in .uri instead of .blob (base64)
+ d.load_uri_to_image_tensor()
+
+ d.tensor = preprocess_fn(d.tensor).detach()
+
+ if return_np:
+ da.tensors = da.tensors.cpu().numpy().astype(np.float32)
+ else:
+ da.tensors = da.tensors.to(device)
+ return da
+
+
+def preproc_text(
+ da: 'DocumentArray', device: str = 'cpu', return_np: bool = False
+) -> Tuple['DocumentArray', List[str]]:
+ texts = da.texts
+ da.tensors = clip.tokenize(texts).detach()
+
+ if return_np:
+ da.tensors = da.tensors.cpu().numpy().astype(np.int64)
+ else:
+ da.tensors = da.tensors.to(device)
+
+ da[:, 'mime_type'] = 'text'
+ return da, texts
+
+
+def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'):
+ if doc.text:
+ txt_da.append(doc)
+ elif doc.blob or (doc.tensor is not None):
+ img_da.append(doc)
+ elif doc.uri:
+ img_da.append(doc)
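For reference, a minimal sketch of how these shared helpers compose outside an Executor, mirroring the ONNX executor's `return_np=True` path; the image URI and model name are illustrative:

```python
# Hedged sketch: split a mixed DocumentArray and preprocess each modality.
from docarray import Document, DocumentArray

from clip_server.model import clip
from clip_server.executors.helper import split_img_txt_da, preproc_image, preproc_text

docs = DocumentArray(
    [
        Document(text='hello, world'),
        Document(uri='https://docarray.jina.ai/_static/favicon.png'),  # any reachable image
    ]
)

img_da, txt_da = DocumentArray(), DocumentArray()
for d in docs:
    split_img_txt_da(d, img_da, txt_da)

txt_da, texts = preproc_text(txt_da, return_np=True)  # int64 token ids
preprocess_fn = clip._transform_ndarray(clip.MODEL_SIZE['ViT-B/32'])
img_da = preproc_image(img_da, preprocess_fn=preprocess_fn, return_np=True)  # float32 pixels

print(txt_da.tensors.shape, img_da.tensors.shape)  # (1, 77) and (1, 3, 224, 224)
```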
diff --git a/server/clip_server/model/clip.py b/server/clip_server/model/clip.py
index 157736104..35a01e498 100644
--- a/server/clip_server/model/clip.py
+++ b/server/clip_server/model/clip.py
@@ -37,6 +37,18 @@
'ViT-L/14@336px': 'ViT-L-14-336px.pt',
}
+MODEL_SIZE = {
+ 'RN50': 224,
+ 'RN101': 224,
+ 'RN50x4': 288,
+ 'RN50x16': 384,
+ 'RN50x64': 448,
+ 'ViT-B/32': 224,
+ 'ViT-B/16': 224,
+ 'ViT-L/14': 224,
+ 'ViT-L/14@336px': 336,
+}
+
def _download(url: str, root: str, with_resume: bool = True):
os.makedirs(root, exist_ok=True)
@@ -223,7 +235,6 @@ def load(
model.float()
return (
model,
- _transform_blob(model.visual.input_resolution),
_transform_ndarray(model.visual.input_resolution),
)
@@ -292,7 +303,6 @@ def patch_float(module):
return (
model,
- _transform_blob(model.input_resolution.item()),
_transform_ndarray(model.input_resolution.item()),
)
diff --git a/server/clip_server/model/clip_onnx.py b/server/clip_server/model/clip_onnx.py
index 649f704fe..aa022e19a 100644
--- a/server/clip_server/model/clip_onnx.py
+++ b/server/clip_server/model/clip_onnx.py
@@ -1,7 +1,5 @@
import os
-import onnxruntime as ort
-
from .clip import _download, available_models
_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/'
@@ -37,6 +35,8 @@ def start_sessions(
self,
**kwargs,
):
+ import onnxruntime as ort
+
self._visual_session = ort.InferenceSession(self._visual_path, **kwargs)
self._visual_session.disable_fallback()
diff --git a/server/clip_server/model/clip_trt.py b/server/clip_server/model/clip_trt.py
new file mode 100644
index 000000000..0ae5f6b7c
--- /dev/null
+++ b/server/clip_server/model/clip_trt.py
@@ -0,0 +1,110 @@
+import os
+
+try:
+ import tensorrt as trt
+ from tensorrt.tensorrt import Logger, Runtime
+
+ from clip_server.model.trt_utils import load_engine, build_engine, save_engine
+except ImportError:
+ raise ImportError(
+ "It seems that TensorRT is not yet installed. "
+ "It is required when you declare TensorRT backend."
+ "Please find installation instruction on "
+ "https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html"
+ )
+
+from .clip import _download, MODEL_SIZE
+
+_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/tensorrt/'
+_MODELS = {
+ 'RN50': ('RN50/textual.trt', 'RN50/visual.trt'),
+ 'RN101': ('RN101/textual.trt', 'RN101/visual.trt'),
+ 'RN50x4': ('RN50x4/textual.trt', 'RN50x4/visual.trt'),
+ # 'RN50x16': ('RN50x16/textual.trt', 'RN50x16/visual.trt'),
+ # 'RN50x64': ('RN50x64/textual.trt', 'RN50x64/visual.trt'),
+ 'ViT-B/32': ('ViT-B-32/textual.trt', 'ViT-B-32/visual.trt'),
+ 'ViT-B/16': ('ViT-B-16/textual.trt', 'ViT-B-16/visual.trt'),
+ 'ViT-L/14': ('ViT-L-14/textual.trt', 'ViT-L-14/visual.trt'),
+}
+
+
+class CLIPTensorRTModel:
+ def __init__(
+ self,
+ name: str = None,
+ ):
+ if name in _MODELS:
+ cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}')
+ self._textual_path = _download(_S3_BUCKET + _MODELS[name][0], cache_dir)
+ self._visual_path = _download(_S3_BUCKET + _MODELS[name][1], cache_dir)
+ else:
+ raise RuntimeError(
+                f'Model {name} not found or does not support the Nvidia TensorRT backend; available models = {list(_MODELS.keys())}'
+ )
+ self._name = name
+
+ def start_engines(self):
+ import torch
+
+ trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
+ runtime: Runtime = trt.Runtime(trt_logger)
+ compute_capacity = torch.cuda.get_device_capability()
+
+ if compute_capacity != (8, 6):
+ print(
+                f'The pre-built engine plan was generated for compute capability 8.6, but this device has '
+                f'compute capability {compute_capacity}; will rebuild the TensorRT engine.'
+ )
+ from clip_server.model.clip_onnx import CLIPOnnxModel
+
+ onnx_model = CLIPOnnxModel(self._name)
+
+ visual_engine = build_engine(
+ runtime=runtime,
+ onnx_file_path=onnx_model._visual_path,
+ logger=trt_logger,
+ min_shape=(1, 3, MODEL_SIZE[self._name], MODEL_SIZE[self._name]),
+ optimal_shape=(
+ 768,
+ 3,
+ MODEL_SIZE[self._name],
+ MODEL_SIZE[self._name],
+ ),
+ max_shape=(
+ 1024,
+ 3,
+ MODEL_SIZE[self._name],
+ MODEL_SIZE[self._name],
+ ),
+ workspace_size=10000 * 1024 * 1024,
+ fp16=False,
+ int8=False,
+ )
+
+ save_engine(visual_engine, self._visual_path)
+
+ text_engine = build_engine(
+ runtime=runtime,
+ onnx_file_path=onnx_model._textual_path,
+ logger=trt_logger,
+ min_shape=(1, 77),
+ optimal_shape=(768, 77),
+ max_shape=(1024, 77),
+ workspace_size=10000 * 1024 * 1024,
+ fp16=False,
+ int8=False,
+ )
+ save_engine(text_engine, self._textual_path)
+
+ self._textual_engine = load_engine(runtime, self._textual_path)
+ self._visual_engine = load_engine(runtime, self._visual_path)
+
+ def encode_image(self, onnx_image):
+ (visual_output,) = self._visual_engine({'input': onnx_image})
+
+ return visual_output
+
+ def encode_text(self, onnx_text):
+ (textual_output,) = self._textual_engine({'input': onnx_text})
+
+ return textual_output
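A hedged sketch of driving `CLIPTensorRTModel` directly, outside the Jina Executor (requires a CUDA GPU with TensorRT installed; the dummy image batch is illustrative):

```python
# Hedged sketch: use the TensorRT model wrapper standalone.
import torch

from clip_server.model import clip
from clip_server.model.clip_trt import CLIPTensorRTModel

model = CLIPTensorRTModel('ViT-B/32')
model.start_engines()  # downloads pre-built engines, or rebuilds them from ONNX

# Text: tokenize to (batch, 77) token ids and move them to the GPU.
tokens = clip.tokenize(['a photo of a cat', 'a photo of a dog']).cuda()
text_emb = model.encode_text(tokens)

# Image: (batch, 3, 224, 224) float32 tensors on the GPU for ViT-B/32.
images = torch.rand(2, 3, 224, 224, device='cuda')
img_emb = model.encode_image(images)

print(text_emb.shape, img_emb.shape)
```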
diff --git a/server/clip_server/model/trt_utils.py b/server/clip_server/model/trt_utils.py
new file mode 100644
index 000000000..351311f48
--- /dev/null
+++ b/server/clip_server/model/trt_utils.py
@@ -0,0 +1,259 @@
+# Originally from https://github.com/ELS-RD/transformer-deploy.
+# Apache License, Version 2.0, Copyright (c) 2022 Lefebvre Dalloz Services
+
+from typing import Callable, Dict, List, OrderedDict, Tuple
+
+import tensorrt as trt
+import torch
+from tensorrt import ICudaEngine, IExecutionContext
+from tensorrt.tensorrt import (
+ Builder,
+ IBuilderConfig,
+ IElementWiseLayer,
+ ILayer,
+ INetworkDefinition,
+ IOptimizationProfile,
+ IReduceLayer,
+ Logger,
+ OnnxParser,
+ Runtime,
+)
+
+
+"""
+All the tooling to ease TensorRT usage.
+"""
+
+
+def fix_fp16_network(network_definition: INetworkDefinition) -> INetworkDefinition:
+ """
+    Mixed precision on TensorRT can produce scores very far from PyTorch because some operators saturate.
+    Indeed, FP16 cannot store very large and very small numbers like FP32 can.
+    Here, we search for patterns of operators to keep in FP32; in most cases this is enough to fix the inference
+    and does not hurt performance.
+ :param network_definition: graph generated by TensorRT after parsing ONNX file (during the model building)
+ :return: patched network definition
+ """
+    # search for patterns which may overflow in FP16 precision; force FP32 precision for those nodes
+ for layer_index in range(network_definition.num_layers - 1):
+ layer: ILayer = network_definition.get_layer(layer_index)
+ next_layer: ILayer = network_definition.get_layer(layer_index + 1)
+ # POW operation usually followed by mean reduce
+ if (
+ layer.type == trt.LayerType.ELEMENTWISE
+ and next_layer.type == trt.LayerType.REDUCE
+ ):
+ # casting to get access to op attribute
+ layer.__class__ = IElementWiseLayer
+ next_layer.__class__ = IReduceLayer
+ if layer.op == trt.ElementWiseOperation.POW:
+ layer.precision = trt.DataType.FLOAT
+ next_layer.precision = trt.DataType.FLOAT
+ layer.set_output_type(index=0, dtype=trt.DataType.FLOAT)
+ next_layer.set_output_type(index=0, dtype=trt.DataType.FLOAT)
+ return network_definition
+
+
+def build_engine(
+ runtime: Runtime,
+ onnx_file_path: str,
+ logger: Logger,
+ min_shape: Tuple[int, int],
+ optimal_shape: Tuple[int, int],
+ max_shape: Tuple[int, int],
+ workspace_size: int,
+ fp16: bool,
+ int8: bool,
+) -> ICudaEngine:
+ """
+ Convert ONNX file to TensorRT engine.
+    It supports dynamic shapes; however, it is advised to keep the sequence length fixed, as varying it hurts performance.
+    A dynamic batch size does not hurt performance and is highly advised.
+    :param runtime: global variable shared across inference calls / model building
+ :param onnx_file_path: path to the ONNX file
+ :param logger: specific logger to TensorRT
+ :param min_shape: the minimal shape of input tensors. It's advised to set first dimension (batch size) to 1
+ :param optimal_shape: input tensor shape used for optimizations
+ :param max_shape: maximal input tensor shape
+    :param workspace_size: GPU memory to use during the build, more is always better. If there is not enough memory,
+    some optimizations may fail, and the whole conversion process will crash.
+    :param fp16: enable FP16 precision, it usually provides a 20-30% boost compared to ONNX Runtime.
+    :param int8: enable INT-8 quantization, best performance but the model should have been quantized beforehand.
+ :return: TensorRT engine to use during inference
+ """
+ with trt.Builder(logger) as builder: # type: Builder
+ with builder.create_network(
+ flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+ ) as network_definition: # type: INetworkDefinition
+ with trt.OnnxParser(
+ network_definition, logger
+ ) as parser: # type: OnnxParser
+ builder.max_batch_size = max_shape[0] # max batch size
+ config: IBuilderConfig = builder.create_builder_config()
+ config.max_workspace_size = workspace_size
+ # to enable complete trt inspector debugging, only for TensorRT >= 8.2
+ config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
+ # disable CUDNN optimizations
+ config.set_tactic_sources(
+ tactic_sources=1 << int(trt.TacticSource.CUBLAS)
+ | 1 << int(trt.TacticSource.CUBLAS_LT)
+ )
+ if int8:
+ config.set_flag(trt.BuilderFlag.INT8)
+ if fp16:
+ config.set_flag(trt.BuilderFlag.FP16)
+ config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
+ # https://github.com/NVIDIA/TensorRT/issues/1196 (sometimes big diff in output when using FP16)
+ config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
+ with open(onnx_file_path, "rb") as f:
+ parser.parse(f.read())
+ profile: IOptimizationProfile = builder.create_optimization_profile()
+ for num_input in range(network_definition.num_inputs):
+ profile.set_shape(
+ input=network_definition.get_input(num_input).name,
+ min=min_shape,
+ opt=optimal_shape,
+ max=max_shape,
+ )
+ config.add_optimization_profile(profile)
+ if fp16:
+ network_definition = fix_fp16_network(network_definition)
+ trt_engine = builder.build_serialized_network(
+ network_definition, config
+ )
+ engine: ICudaEngine = runtime.deserialize_cuda_engine(trt_engine)
+ assert (
+ engine is not None
+ ), "error during engine generation, check error messages above :-("
+ return engine
+
+
+def get_output_tensors(
+ context: trt.IExecutionContext,
+ host_inputs: List[torch.Tensor],
+ input_binding_idxs: List[int],
+ output_binding_idxs: List[int],
+) -> List[torch.Tensor]:
+ """
+    Set input shapes on the context and reserve GPU memory for the output tensors.
+    :param context: TensorRT context shared across inference steps
+    :param host_inputs: input tensors
+    :param input_binding_idxs: indexes of each input vector (should be the same as during building)
+    :param output_binding_idxs: indexes of each output vector (should be the same as during building)
+ :return: tensors where output will be stored
+ """
+ # explicitly set dynamic input shapes, so dynamic output shapes can be computed internally
+ for host_input, binding_index in zip(host_inputs, input_binding_idxs):
+ context.set_binding_shape(binding_index, tuple(host_input.shape))
+ assert context.all_binding_shapes_specified
+ device_outputs: List[torch.Tensor] = []
+ for binding_index in output_binding_idxs:
+ # TensorRT computes output shape based on input shape provided above
+ output_shape = context.get_binding_shape(binding_index)
+ # allocate buffers to hold output results
+ output = torch.empty(tuple(output_shape), device="cuda")
+ device_outputs.append(output)
+ return device_outputs
+
+
+def infer_tensorrt(
+ context: IExecutionContext,
+ host_inputs: OrderedDict[str, torch.Tensor],
+ input_binding_idxs: List[int],
+ output_binding_idxs: List[int],
+) -> List[torch.Tensor]:
+ """
+ Perform inference with TensorRT.
+ :param context: shared variable
+ :param host_inputs: input tensor
+ :param input_binding_idxs: input tensor indexes
+ :param output_binding_idxs: output tensor indexes
+ :return: output tensor
+ """
+ input_tensors: List[torch.Tensor] = list()
+ for tensor in host_inputs.values():
+ assert isinstance(
+ tensor, torch.Tensor
+ ), f"unexpected tensor type: {tensor.dtype}"
+
+ if tensor.dtype == torch.int64:
+ # warning: small changes in output if int64 is used instead of int32
+ tensor = tensor.type(torch.int32)
+ # tensor = tensor.to("cuda")
+ input_tensors.append(tensor)
+ # calculate input shape, bind it, allocate GPU memory for the output
+ output_tensors: List[torch.Tensor] = get_output_tensors(
+ context, input_tensors, input_binding_idxs, output_binding_idxs
+ )
+ bindings = [int(i.data_ptr()) for i in input_tensors + output_tensors]
+ assert context.execute_async_v2(
+ bindings, torch.cuda.current_stream().cuda_stream
+ ), "failure during execution of inference"
+ torch.cuda.current_stream().synchronize() # sync all CUDA ops
+ return output_tensors
+
+
+def load_engine(
+ runtime: Runtime, engine_file_path: str, profile_index: int = 0
+) -> Callable[[Dict[str, torch.Tensor]], torch.Tensor]:
+ """
+ Load serialized TensorRT engine.
+ :param runtime: shared variable
+ :param engine_file_path: path to the serialized engine
+ :param profile_index: which profile to load, 0 if you have not used multiple profiles
+ :return: A function to perform inference
+ """
+ with open(file=engine_file_path, mode="rb") as f:
+ engine: ICudaEngine = runtime.deserialize_cuda_engine(f.read())
+ stream: int = torch.cuda.current_stream().cuda_stream
+ context: IExecutionContext = engine.create_execution_context()
+ context.set_optimization_profile_async(
+ profile_index=profile_index, stream_handle=stream
+ )
+ # retrieve input/output IDs
+ input_binding_idxs, output_binding_idxs = get_binding_idxs(
+ engine, profile_index
+ ) # type: List[int], List[int]
+
+ def tensorrt_model(inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+ return infer_tensorrt(
+ context=context,
+ host_inputs=inputs,
+ input_binding_idxs=input_binding_idxs,
+ output_binding_idxs=output_binding_idxs,
+ )
+
+ return tensorrt_model
+
+
+def save_engine(engine: ICudaEngine, engine_file_path: str) -> None:
+ """
+ Serialize TensorRT engine to file.
+ :param engine: TensorRT engine
+ :param engine_file_path: output path
+ """
+ with open(engine_file_path, "wb") as f:
+ f.write(engine.serialize())
+
+
+def get_binding_idxs(engine: trt.ICudaEngine, profile_index: int):
+ """
+ Calculate start/end binding indices for current context's profile
+ https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#opt_profiles_bindings
+ :param engine: TensorRT engine generated during the model building
+ :param profile_index: profile to use (several profiles can be set during building)
+ :return: input and output tensor indexes
+ """
+ num_bindings_per_profile = engine.num_bindings // engine.num_optimization_profiles
+ start_binding = profile_index * num_bindings_per_profile
+ end_binding = (
+ start_binding + num_bindings_per_profile
+ ) # Separate input and output binding indices for convenience
+ input_binding_idxs: List[int] = []
+ output_binding_idxs: List[int] = []
+ for binding_index in range(start_binding, end_binding):
+ if engine.binding_is_input(binding_index):
+ input_binding_idxs.append(binding_index)
+ else:
+ output_binding_idxs.append(binding_index)
+ return input_binding_idxs, output_binding_idxs
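A hedged sketch of using these helpers standalone to convert an ONNX model and run it; the ONNX path, the shapes and the `input` binding name are assumptions that mirror how `clip_trt.py` calls them:

```python
# Hedged sketch: build, save, load and run a TensorRT engine with the helpers above.
import tensorrt as trt
import torch

from clip_server.model.trt_utils import build_engine, save_engine, load_engine

trt_logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(trt_logger)

engine = build_engine(
    runtime=runtime,
    onnx_file_path='visual.onnx',  # hypothetical ONNX export of the visual tower
    logger=trt_logger,
    min_shape=(1, 3, 224, 224),
    optimal_shape=(16, 3, 224, 224),
    max_shape=(32, 3, 224, 224),
    workspace_size=4 * 1024 * 1024 * 1024,
    fp16=False,
    int8=False,
)
save_engine(engine, 'visual.trt')

# load_engine returns a callable: Dict[str, torch.Tensor] -> List[torch.Tensor]
infer = load_engine(runtime, 'visual.trt')
(embeddings,) = infer({'input': torch.rand(8, 3, 224, 224, device='cuda')})
print(embeddings.shape)
```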
diff --git a/server/clip_server/tensorrt-flow.yml b/server/clip_server/tensorrt-flow.yml
new file mode 100644
index 000000000..3f37763e1
--- /dev/null
+++ b/server/clip_server/tensorrt-flow.yml
@@ -0,0 +1,12 @@
+jtype: Flow
+version: '1'
+with:
+ port: 51000
+executors:
+ - name: clip_r
+ uses:
+ jtype: CLIPEncoder
+ metas:
+ py_modules:
+ - executors/clip_trt.py
+
diff --git a/server/setup.py b/server/setup.py
index f3cfc0364..9dc283592 100644
--- a/server/setup.py
+++ b/server/setup.py
@@ -56,6 +56,7 @@
'onnx',
]
+ (['onnxruntime-gpu>=1.8.0'] if sys.platform != 'darwin' else []),
+ 'tensorrt': ['nvidia-tensorrt'],
},
classifiers=[
'Development Status :: 5 - Production/Stable',
diff --git a/tests/conftest.py b/tests/conftest.py
index b7023b8a8..078d24985 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -35,3 +35,12 @@ def make_torch_flow(port_generator, request):
f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
with f:
yield f
+
+
+@pytest.fixture(scope='session', params=['torch'])
+def make_trt_flow(port_generator, request):
+ from clip_server.executors.clip_trt import CLIPEncoder
+
+ f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
+ with f:
+ yield f
diff --git a/tests/test_tensorrt.py b/tests/test_tensorrt.py
new file mode 100644
index 000000000..a06efe589
--- /dev/null
+++ b/tests/test_tensorrt.py
@@ -0,0 +1,38 @@
+import os
+
+import pytest
+from docarray import Document, DocumentArray
+from jina import Flow
+
+from clip_client.client import Client
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(
+ 'inputs',
+ [
+ [Document(text='hello, world'), Document(text='goodbye, world')],
+ DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]),
+ lambda: (Document(text='hello, world') for _ in range(10)),
+ DocumentArray(
+ [
+ Document(uri='https://docarray.jina.ai/_static/favicon.png'),
+ Document(
+ uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
+ ),
+ Document(text='hello, world'),
+ Document(
+ uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
+ ).load_uri_to_image_tensor(),
+ ]
+ ),
+ DocumentArray.from_files(
+ f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
+ ),
+ ],
+)
+def test_docarray_inputs(make_trt_flow, inputs):
+ c = Client(server=f'grpc://0.0.0.0:{make_trt_flow.port}')
+ r = c.encode(inputs if not callable(inputs) else inputs())
+ assert isinstance(r, DocumentArray)
+ assert r.embeddings.shape
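For completeness, a hedged end-to-end sketch of querying the TensorRT-backed server from a client, assuming `python -m clip_server tensorrt-flow.yml` is already running on the default port 51000:

```python
# Hedged sketch: embed text and images against a running TensorRT clip_server.
from docarray import Document, DocumentArray

from clip_client.client import Client

c = Client(server='grpc://0.0.0.0:51000')  # port defined in tensorrt-flow.yml
da = c.encode(
    DocumentArray(
        [
            Document(text='hello, world'),
            Document(uri='https://docarray.jina.ai/_static/favicon.png'),
        ]
    )
)
print(da.embeddings.shape)  # (2, embedding_dim)
```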