fix: dynamic convert onnx model to fp16 during start session #876

Merged · 14 commits · Dec 12, 2022
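In short: the FP32→FP16 conversion moves out of `__init__` (where, as the removed lines below show, the converted models were saved back over the original `textual.onnx`/`visual.onnx` paths, clobbering the cached FP32 files) and into `start_sessions`, which converts into a temporary directory on demand. The conversion itself uses `onnxmltools`; a minimal standalone sketch of that call, with placeholder file names:

    import onnx
    from onnxmltools.utils import float16_converter

    # Read an FP32 ONNX model from disk and convert its float32 tensors
    # to float16; 'model.onnx' / 'model_fp16.onnx' are placeholder paths.
    model_fp16 = float16_converter.convert_float_to_float16_model_path('model.onnx')
    onnx.save_model(model_fp16, 'model_fp16.onnx')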
server/clip_server/executors/clip_onnx.py (3 additions, 1 deletion)

@@ -103,7 +103,9 @@ def __init__(
         sess_options.inter_op_num_threads = 1
         sess_options.intra_op_num_threads = max(num_threads, 1)

-        self._model.start_sessions(sess_options=sess_options, providers=providers)
+        self._model.start_sessions(
+            sess_options=sess_options, providers=providers, dtype=dtype
+        )

         if not self.tracer:
             self.tracer = NoOpTracer()
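For reference, a hedged sketch of the new call shape from the executor's side; `model` stands in for the `CLIPOnnxModel` instance behind `self._model`, and the provider list and thread counts are illustrative, not the executor's exact configuration:

    import onnxruntime as ort

    # Session options mirroring the defaults shown in the hunk above.
    sess_options = ort.SessionOptions()
    sess_options.inter_op_num_threads = 1
    sess_options.intra_op_num_threads = 1

    # 'model' is a hypothetical CLIPOnnxModel instance; dtype now travels
    # with the session start instead of being baked in at construction.
    model.start_sessions(
        sess_options=sess_options,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        dtype='fp16',
    )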
server/clip_server/model/clip_onnx.py (26 additions, 36 deletions)

@@ -240,22 +240,6 @@ def __init__(
                         f'The given model path {model_path} should be a folder containing both '
                         f'`textual.onnx` and `visual.onnx`.'
                     )
-            if dtype == 'fp16':
-                import onnx
-                from onnxmltools.utils import float16_converter
-
-                _textual_model_fp16 = (
-                    float16_converter.convert_float_to_float16_model_path(
-                        self._textual_path
-                    )
-                )
-                _visual_model_fp16 = (
-                    float16_converter.convert_float_to_float16_model_path(
-                        self._visual_path
-                    )
-                )
-                onnx.save_model(_textual_model_fp16, self._textual_path)
-                onnx.save_model(_visual_model_fp16, self._visual_path)
         else:
             raise RuntimeError(
                 'CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}'.format(

@@ -279,33 +263,39 @@ def get_model_name(name: str):

     def start_sessions(
         self,
+        dtype,
         **kwargs,
     ):
         import onnxruntime as ort

-        def _load_session_from_zip(model_path: str, model_type: str):
-            """Load a model from a zip file."""
-            import zipfile
-            import tempfile
-
-            with zipfile.ZipFile(
-                model_path, 'r'
-            ) as zip_ref, tempfile.TemporaryDirectory() as tmp_dir:
-                zip_ref.extractall(tmp_dir)
-                return ort.InferenceSession(tmp_dir + f'/{model_type}.onnx', **kwargs)
-
-        if self._visual_path.endswith('.zip'):
-            self._visual_session = _load_session_from_zip(self._visual_path, 'visual')
-        else:
-            self._visual_session = ort.InferenceSession(self._visual_path, **kwargs)
-        self._visual_session.disable_fallback()
-
-        if self._textual_path.endswith('.zip'):
-            self._textual_session = _load_session_from_zip(
-                self._textual_path, 'textual'
-            )
-        else:
-            self._textual_session = ort.InferenceSession(self._textual_path, **kwargs)
-        self._textual_session.disable_fallback()
+        def _load_session(model_path: str, model_type: str, dtype: str):
+            if model_path.endswith('.zip') or dtype == 'fp16':
+                import tempfile
+
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    tmp_model_path = tmp_dir + f'/{model_type}.onnx'
+                    if model_path.endswith('.zip'):
+                        import zipfile
+
+                        with zipfile.ZipFile(model_path, 'r') as zip_ref:
+                            zip_ref.extractall(tmp_dir)
+                        model_path = tmp_model_path
+                    if dtype == 'fp16':
+                        import onnx
+                        from onnxmltools.utils import float16_converter
+
+                        model_fp16 = (
+                            float16_converter.convert_float_to_float16_model_path(
+                                model_path
+                            )
+                        )
+                        onnx.save_model(model_fp16, tmp_model_path)
+                    return ort.InferenceSession(tmp_model_path, **kwargs)
+            return ort.InferenceSession(model_path, **kwargs)
+
+        self._visual_session = _load_session(self._visual_path, 'visual', dtype)
+        self._textual_session = _load_session(self._textual_path, 'textual', dtype)
+        self._visual_session.disable_fallback()
+        self._textual_session.disable_fallback()

     def encode_image(self, image_input: Dict):
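A note on the design: since `ort.InferenceSession` loads the model into memory during construction, the temporary directory can be cleaned up as soon as the `with` block exits, so neither the extracted zip contents nor the FP16 copy ever touch the model cache. To sanity-check a converted model, one could inspect the session's I/O types (placeholder path; with onnxmltools' default settings the element types should generally report as `tensor(float16)` after conversion):

    import onnxruntime as ort

    # Placeholder path to an fp16-converted model produced as above.
    session = ort.InferenceSession('visual_fp16.onnx', providers=['CPUExecutionProvider'])
    session.disable_fallback()

    # Inputs and outputs of the converted graph should report float16.
    for node in session.get_inputs() + session.get_outputs():
        print(node.name, node.type)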