feat: support onnx backend for openclip #781

Merged: 8 commits, Jul 26, 2022

server/clip_server/executors/clip_onnx.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ def __init__(
self._model = CLIPOnnxModel(name, model_path)
self._tokenizer = Tokenizer(name)

self._image_transform = clip._transform_ndarray(clip.MODEL_SIZE[name])
self._image_transform = clip._transform_ndarray(self._model.image_size)

import torch

server/clip_server/executors/clip_tensorrt.py (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@ def __init__(
self._model.start_engines()

self._tokenizer = Tokenizer(name)
self._image_transform = clip._transform_ndarray(clip.MODEL_SIZE[name])
self._image_transform = clip._transform_ndarray(self._model.image_size)

def _preproc_images(self, docs: 'DocumentArray'):
with self.monitor(
server/clip_server/executors/clip_torch.py (1 change: 0 additions & 1 deletion)
@@ -59,7 +59,6 @@ def __init__(

self._model = CLIPModel(name, device=self._device, jit=jit, **kwargs)
self._tokenizer = Tokenizer(name)

self._image_transform = clip._transform_ndarray(self._model.image_size)

def _preproc_images(self, docs: 'DocumentArray'):
server/clip_server/model/clip_model.py (31 changes: 19 additions & 12 deletions)
@@ -5,7 +5,25 @@
)


class CLIPModel:
class BaseCLIPModel:
def __init__(self, name: str, **kwargs):
super().__init__()
self._name = name

@staticmethod
def get_model_name(name: str):
return name

@property
def model_name(self):
return self.__class__.get_model_name(self._name)
Comment on lines +13 to +19

Member:
can we simply return xxx.name?

Member (Author):
To get the model name without initializing the instance (avoiding downloading and loading weights), we must use a staticmethod. That's the reason I implemented get_model_name.
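
For context, a minimal sketch (not part of the diff; it only assumes the BaseCLIPModel shape shown above) of why a staticmethod lets callers resolve the canonical model name without constructing the model:

    # Minimal sketch, assuming the BaseCLIPModel shape shown in this diff.
    # Resolving the name via a staticmethod needs no instance, so no weights
    # are downloaded or loaded.
    class BaseCLIPModel:
        def __init__(self, name: str):
            self._name = name  # real subclasses also download and load weights here

        @staticmethod
        def get_model_name(name: str) -> str:
            return name  # subclasses normalize aliases, e.g. 'ViT-L/14@336px' -> 'ViT-L-14-336'

        @property
        def model_name(self) -> str:
            # the instance property simply delegates to the subclass's staticmethod
            return self.__class__.get_model_name(self._name)

    # Name resolution without instantiation (and thus without loading weights):
    print(BaseCLIPModel.get_model_name('ViT-B-32::openai'))

An instance attribute such as xxx.name would require building the object first, which is exactly what the staticmethod avoids.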


@property
def image_size(self):
return _VISUAL_MODEL_IMAGE_SIZE.get(self.model_name, None)


class CLIPModel(BaseCLIPModel):
def __new__(cls, name: str, **kwargs):
if cls is CLIPModel:
if name in _OPENCLIP_MODELS:
@@ -21,14 +39,3 @@ def __new__(cls, name: str, **kwargs):
else:
instance = super().__new__(cls)
return instance

def __init__(self, name: str, **kwargs):
self._name = name

@property
def model_name(self):
return self._name

@property
def image_size(self):
return _VISUAL_MODEL_IMAGE_SIZE.get(self.model_name, None)
server/clip_server/model/clip_onnx.py (48 changes: 30 additions & 18 deletions)
@@ -1,7 +1,12 @@
import os
from typing import Dict

from clip_server.model.clip import available_models
from clip_server.model.pretrained_models import download_model
from clip_server.model.pretrained_models import (
download_model,
_OPENCLIP_MODELS,
_MULTILINGUALCLIP_MODELS,
)
from clip_server.model.clip_model import BaseCLIPModel

_S3_BUCKET = (
'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/' # Deprecated
@@ -14,16 +19,11 @@
),
'RN50::yfcc15m': (),
'RN50::cc12m': (),
'RN50-quickgelu::openai': (),
'RN50-quickgelu::yfcc15m': (),
'RN50-quickgelu::cc12m': (),
'RN101::openai': (
('RN101/textual.onnx', '2d9efb7d184c0d68a369024cedfa97af'),
('RN101/visual.onnx', '0297ebc773af312faab54f8b5a622d71'),
),
'RN101::yfcc15m': (),
'RN101-quickgelu::openai': (),
'RN101-quickgelu::yfcc15m': (),
'RN50x4::openai': (
('RN50x4/textual.onnx', 'd9d63d3fe35fb14d4affaa2c4e284005'),
('RN50x4/visual.onnx', '16afe1e35b85ad862e8bbdb12265c9cb'),
@@ -43,9 +43,6 @@
'ViT-B-32::laion2b_e16': (),
'ViT-B-32::laion400m_e31': (),
'ViT-B-32::laion400m_e32': (),
'ViT-B-32-quickgelu::openai': (),
'ViT-B-32-quickgelu::laion400m_e31': (),
'ViT-B-32-quickgelu::laion400m_e32': (),
'ViT-B-16::openai': (
('ViT-B-16/textual.onnx', '6f0976629a446f95c0c8767658f12ebe'),
('ViT-B-16/visual.onnx', 'd5c03bfeef1abbd9bede54a8f6e1eaad'),
@@ -102,8 +99,9 @@
}


class CLIPOnnxModel:
def __init__(self, name: str = None, model_path: str = None):
class CLIPOnnxModel(BaseCLIPModel):
def __init__(self, name: str, model_path: str = None):
super().__init__(name)
if name in _MODELS:
if not model_path:
cache_dir = os.path.expanduser(
@@ -135,13 +133,27 @@ def __init__(self, name: str = None, model_path: str = None):
)
else:
raise RuntimeError(
f'The given model path {model_path} is not a valid directory'
f'The given model path {model_path} should be a folder containing both '
f'`textual.onnx` and `visual.onnx`.'
)
else:
raise RuntimeError(
f'Model {name} not found; available models = {available_models()}'
f'Model {name} not found; available models = {list(_MODELS.keys())}'
)

@staticmethod
def get_model_name(name: str):
if name in _OPENCLIP_MODELS:
from clip_server.model.openclip_model import OpenCLIPModel

return OpenCLIPModel.get_model_name(name)
elif name in _MULTILINGUALCLIP_MODELS:
from clip_server.model.mclip_model import MultilingualCLIPModel

return MultilingualCLIPModel.get_model_name(name)

return name

def start_sessions(
self,
**kwargs,
@@ -154,10 +166,10 @@ def start_sessions(
self._textual_session = ort.InferenceSession(self._textual_path, **kwargs)
self._textual_session.disable_fallback()

def encode_image(self, onnx_image):
(visual_output,) = self._visual_session.run(None, onnx_image)
def encode_image(self, image_input: Dict):
(visual_output,) = self._visual_session.run(None, image_input)
return visual_output

def encode_text(self, onnx_text):
(textual_output,) = self._textual_session.run(None, onnx_text)
def encode_text(self, text_input: Dict):
(textual_output,) = self._textual_session.run(None, text_input)
return textual_output
server/clip_server/model/clip_trt.py (48 changes: 33 additions & 15 deletions)
@@ -1,4 +1,5 @@
import os
from typing import Dict

try:
import tensorrt as trt
@@ -12,8 +13,11 @@
"Please find installation instruction on "
"https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html"
)

from clip_server.model.clip import MODEL_SIZE
from clip_server.model.pretrained_models import (
_OPENCLIP_MODELS,
_MULTILINGUALCLIP_MODELS,
)
from clip_server.model.clip_model import BaseCLIPModel
from clip_server.model.clip_onnx import _MODELS as ONNX_MODELS

_MODELS = [
@@ -29,13 +33,14 @@
]


class CLIPTensorRTModel:
class CLIPTensorRTModel(BaseCLIPModel):
def __init__(
self,
name: str = None,
name: str,
):
super().__init__(name)

if name in _MODELS:
self._name = name
cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}')

self._textual_path = os.path.join(
@@ -54,24 +59,24 @@ def __init__(

trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
onnx_model = CLIPOnnxModel(self._name)
onnx_model = CLIPOnnxModel(name)

visual_engine = build_engine(
runtime=runtime,
onnx_file_path=onnx_model._visual_path,
logger=trt_logger,
min_shape=(1, 3, MODEL_SIZE[self._name], MODEL_SIZE[self._name]),
min_shape=(1, 3, onnx_model.image_size, onnx_model.image_size),
optimal_shape=(
768,
3,
MODEL_SIZE[self._name],
MODEL_SIZE[self._name],
onnx_model.image_size,
onnx_model.image_size,
),
max_shape=(
1024,
3,
MODEL_SIZE[self._name],
MODEL_SIZE[self._name],
onnx_model.image_size,
onnx_model.image_size,
),
workspace_size=10000 * 1024 * 1024,
fp16=False,
@@ -96,16 +101,29 @@ def __init__(
f'Model {name} not found or not supports Nvidia TensorRT backend; available models = {list(_MODELS.keys())}'
)

@staticmethod
def get_model_name(name: str):
if name in _OPENCLIP_MODELS:
from clip_server.model.openclip_model import OpenCLIPModel

return OpenCLIPModel.get_model_name(name)
elif name in _MULTILINGUALCLIP_MODELS:
from clip_server.model.mclip_model import MultilingualCLIPModel

return MultilingualCLIPModel.get_model_name(name)

return name

def start_engines(self):
trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
self._textual_engine = load_engine(runtime, self._textual_path)
self._visual_engine = load_engine(runtime, self._visual_path)

def encode_image(self, onnx_image):
(visual_output,) = self._visual_engine(onnx_image)
def encode_image(self, image_input: Dict):
(visual_output,) = self._visual_engine(image_input)
return visual_output

def encode_text(self, onnx_text):
(textual_output,) = self._textual_engine(onnx_text)
def encode_text(self, text_input: Dict):
(textual_output,) = self._textual_engine(text_input)
return textual_output
server/clip_server/model/mclip_model.py (4 changes: 2 additions & 2 deletions)
@@ -63,8 +63,8 @@ def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
self._clip_name = clip_name

@property
def image_size(self):
return _VISUAL_MODEL_IMAGE_SIZE[self._clip_name]
def model_name(self):
return self._clip_name

def encode_text(
self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', **kwargs
server/clip_server/model/openclip_model.py (60 changes: 19 additions & 41 deletions)
@@ -6,17 +6,11 @@
# Ludwig Schmidt

from typing import TYPE_CHECKING
from copy import deepcopy
import torch

from clip_server.model.clip_model import CLIPModel
from clip_server.model.pretrained_models import get_model_url_md5, download_model

from open_clip.model import (
CLIP,
convert_weights_to_fp16,
)
from open_clip.factory import _MODEL_CONFIGS, load_state_dict, load_openai_model
import open_clip
from open_clip.openai import load_openai_model

if TYPE_CHECKING:
import torch
@@ -26,46 +20,30 @@ class OpenCLIPModel(CLIPModel):
def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
super().__init__(name, **kwargs)

if '::' in name:
model_name, pretrained = name.split('::')
else:
# default pretrained model is from openai
model_name = name
pretrained = 'openai'

self._model_name = model_name

model_url, md5sum = get_model_url_md5(name)
model_path = download_model(model_url, md5sum=md5sum)
if pretrained.lower() == 'openai':
if model_url:
model_path = download_model(model_url, md5sum=md5sum)
self._model = load_openai_model(model_path, device=device, jit=jit)
self._model_name = name
else:
if model_name in _MODEL_CONFIGS:
model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
else:
raise RuntimeError(f'Model config for {model_name} not found.')

self._model = CLIP(**model_cfg)

state_dict = load_state_dict(model_path)
self._model.load_state_dict(state_dict, strict=True)

if str(device) == 'cuda':
convert_weights_to_fp16(self._model)
if jit:
self._model = torch.jit.script(self._model)

self._model.to(device=torch.device(device))
self._model.eval()
model_name, pretrained = name.split('::')
self._model = open_clip.create_model(
model_name, pretrained=pretrained, device=device, jit=jit
)
self._model_name = model_name

@property
def model_name(self):
if self._model_name == 'ViT-L/14@336px':
@staticmethod
def get_model_name(name: str):
if '::' in name:
model_name, pretrained = name.split('::')
else:
model_name = name
if model_name == 'ViT-L/14@336px':
return 'ViT-L-14-336'
return self._model_name.replace('/', '-')
return model_name.replace('/', '-')

def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
return self._model.encode_text(input_ids)

def encode_image(self, pixel_values: 'torch.Tensor'):
def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
return self._model.encode_image(pixel_values)