Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: dynamic convert onnx model to fp16 during start session #876

Merged
14 commits merged on Dec 12, 2022
4 changes: 3 additions & 1 deletion server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def __init__(
sess_options.inter_op_num_threads = 1
sess_options.intra_op_num_threads = max(num_threads, 1)

self._model.start_sessions(sess_options=sess_options, providers=providers)
self._model.start_sessions(
sess_options=sess_options, providers=providers, dtype=dtype
)

if not self.tracer:
self.tracer = NoOpTracer()
Expand Down
64 changes: 28 additions & 36 deletions server/clip_server/model/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,22 +240,6 @@ def __init__(
f'The given model path {model_path} should be a folder containing both '
f'`textual.onnx` and `visual.onnx`.'
)
if dtype == 'fp16':
import onnx
from onnxmltools.utils import float16_converter

_textual_model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
self._textual_path
)
)
_visual_model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
self._visual_path
)
)
onnx.save_model(_textual_model_fp16, self._textual_path)
onnx.save_model(_visual_model_fp16, self._visual_path)
else:
raise RuntimeError(
'CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}'.format(
Expand All @@ -279,33 +263,41 @@ def get_model_name(name: str):

def start_sessions(
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved
self,
dtype,
**kwargs,
):
import tempfile
import onnxruntime as ort

def _load_session_from_zip(model_path: str, model_type: str):
"""Load a model from a zip file."""
import zipfile
import tempfile
def _load_session_as_fp16(src_model: str, tmp_model: str, model_type: str):
import onnx
from onnxmltools.utils import float16_converter
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved

with zipfile.ZipFile(
model_path, 'r'
) as zip_ref, tempfile.TemporaryDirectory() as tmp_dir:
zip_ref.extractall(tmp_dir)
return ort.InferenceSession(tmp_dir + f'/{model_type}.onnx', **kwargs)
_model_fp16 = float16_converter.convert_float_to_float16_model_path(
src_model
)
onnx.save_model(_model_fp16, tmp_model)
return ort.InferenceSession(tmp_model, **kwargs)
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved

if self._visual_path.endswith('.zip'):
self._visual_session = _load_session_from_zip(self._visual_path, 'visual')
else:
self._visual_session = ort.InferenceSession(self._visual_path, **kwargs)
self._visual_session.disable_fallback()
def _load_session(model_path: str, model_type: str, dtype: str):
if model_path.endswith('.zip') or dtype == 'fp16':
with tempfile.TemporaryDirectory() as tmp_dir:
src_model = model_path
tmp_model = tmp_dir + f'/{model_type}.onnx'
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved
if model_path.endswith('.zip'):
import zipfile

if self._textual_path.endswith('.zip'):
self._textual_session = _load_session_from_zip(
self._textual_path, 'textual'
)
else:
self._textual_session = ort.InferenceSession(self._textual_path, **kwargs)
with zipfile.ZipFile(model_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir)
src_model = tmp_model
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved
if dtype == 'fp16':
return _load_session_as_fp16(src_model, tmp_model, model_type)
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved
return ort.InferenceSession(tmp_model, **kwargs)
OrangeSodahub marked this conversation as resolved.
Show resolved Hide resolved
return ort.InferenceSession(model_path, **kwargs)

self._visual_session = _load_session(self._visual_path, 'visual', dtype)
self._textual_session = _load_session(self._textual_path, 'textual', dtype)
self._visual_session.disable_fallback()
self._textual_session.disable_fallback()

def encode_image(self, image_input: Dict):
Expand Down