Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: dynamic convert onnx model to fp16 during start session #876

Merged
merged 14 commits into from
Dec 12, 2022
4 changes: 3 additions & 1 deletion server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def __init__(
sess_options.inter_op_num_threads = 1
sess_options.intra_op_num_threads = max(num_threads, 1)

self._model.start_sessions(sess_options=sess_options, providers=providers)
self._model.start_sessions(
sess_options=sess_options, providers=providers, dtype=dtype
)

if not self.tracer:
self.tracer = NoOpTracer()
Expand Down
62 changes: 26 additions & 36 deletions server/clip_server/model/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,22 +240,6 @@ def __init__(
f'The given model path {model_path} should be a folder containing both '
f'`textual.onnx` and `visual.onnx`.'
)
if dtype == 'fp16':
import onnx
from onnxmltools.utils import float16_converter

_textual_model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
self._textual_path
)
)
_visual_model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
self._visual_path
)
)
onnx.save_model(_textual_model_fp16, self._textual_path)
onnx.save_model(_visual_model_fp16, self._visual_path)
else:
raise RuntimeError(
'CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}'.format(
Expand All @@ -279,33 +263,39 @@ def get_model_name(name: str):

def start_sessions(
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved
self,
dtype,
**kwargs,
):
import onnxruntime as ort

def _load_session_from_zip(model_path: str, model_type: str):
"""Load a model from a zip file."""
import zipfile
import tempfile
def _load_session(model_path: str, model_type: str, dtype: str):
if model_path.endswith('.zip') or dtype == 'fp16':
import tempfile

with zipfile.ZipFile(
model_path, 'r'
) as zip_ref, tempfile.TemporaryDirectory() as tmp_dir:
zip_ref.extractall(tmp_dir)
return ort.InferenceSession(tmp_dir + f'/{model_type}.onnx', **kwargs)
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_model_path = tmp_dir + f'/{model_type}.onnx'
if model_path.endswith('.zip'):
import zipfile

if self._visual_path.endswith('.zip'):
self._visual_session = _load_session_from_zip(self._visual_path, 'visual')
else:
self._visual_session = ort.InferenceSession(self._visual_path, **kwargs)
self._visual_session.disable_fallback()
with zipfile.ZipFile(model_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir)
model_path = tmp_model_path
if dtype == 'fp16':
import onnx
from onnxmltools.utils import float16_converter

if self._textual_path.endswith('.zip'):
self._textual_session = _load_session_from_zip(
self._textual_path, 'textual'
)
else:
self._textual_session = ort.InferenceSession(self._textual_path, **kwargs)
model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
model_path
)
)
onnx.save_model(model_fp16, tmp_model_path)
return ort.InferenceSession(tmp_model_path, **kwargs)
return ort.InferenceSession(model_path, **kwargs)

self._visual_session = _load_session(self._visual_path, 'visual', dtype)
self._textual_session = _load_session(self._textual_path, 'textual', dtype)
self._visual_session.disable_fallback()
self._textual_session.disable_fallback()

def encode_image(self, image_input: Dict):
Expand Down
35 changes: 35 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,41 @@ def make_flow(port_generator, request):
yield f


@pytest.fixture(
scope='session',
params=[
['onnx', 'ViT-B-32::openai'],
['onnx', 'ViT-H-14::laion2b-s32b-b79k'],
['torch', 'ViT-B-32::openai'],
['onnx_custom', 'ViT-B-32::openai'],
],
)
def make_flow_with_large(port_generator, request):
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved
if request.param != 'onnx_custom':
if request.param[0] == 'onnx':
from clip_server.executors.clip_onnx import CLIPEncoder
else:
from clip_server.executors.clip_torch import CLIPEncoder

f = Flow(port=port_generator()).add(
name=request.param[0],
uses=CLIPEncoder,
uses_with={'name': request.param[1]},
)
else:
import os
from clip_server.executors.clip_onnx import CLIPEncoder

model_name = request.param[1].replace('::', '-')
f = Flow(port=port_generator()).add(
name=request.param[0],
uses=CLIPEncoder,
uses_with={'model_path': os.path.expanduser(f'~/.cache/clip/{model_name}')},
)
with f:
yield f


@pytest.fixture(scope='session', params=['torch'])
def make_torch_flow(port_generator, request):
from clip_server.executors.clip_torch import CLIPEncoder
Expand Down
8 changes: 4 additions & 4 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_protocols(port_generator, protocol, jit, pytestconfig):
],
],
)
def test_plain_inputs(make_flow, inputs):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
def test_plain_inputs(make_flow_with_large, inputs):
c = Client(server=f'grpc://0.0.0.0:{make_flow_with_large.port}')
r = c.encode(inputs if not callable(inputs) else inputs())
assert (
r.shape[0] == len(list(inputs)) if not callable(inputs) else len(list(inputs()))
Expand Down Expand Up @@ -73,8 +73,8 @@ def test_plain_inputs(make_flow, inputs):
),
],
)
def test_docarray_inputs(make_flow, inputs):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
def test_docarray_inputs(make_flow_with_large, inputs):
c = Client(server=f'grpc://0.0.0.0:{make_flow_with_large.port}')
r = c.encode(inputs if not callable(inputs) else inputs())
assert isinstance(r, DocumentArray)
assert r.embeddings.shape
Expand Down