diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7a64bfa6c..924207527 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -158,6 +158,7 @@ jobs: pip install -e "client/[test]" pip install -e "server/[tensorrt]" pip install -e "server/[onnx]" + pip install -e "server/[transformers]" { pip install -e "server/[flash-attn]" } || { @@ -170,6 +171,8 @@ jobs: -v -s -m "gpu" ./tests/test_tensorrt.py pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \ -v -s -m "gpu" ./tests/test_simple.py + pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \ + -v -s -m "gpu" ./tests/test_model.py echo "::set-output name=codecov_flag::cas" timeout-minutes: 30 env: diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index bb589f912..860c5deb3 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -103,7 +103,9 @@ def __init__( sess_options.inter_op_num_threads = 1 sess_options.intra_op_num_threads = max(num_threads, 1) - self._model.start_sessions(sess_options=sess_options, providers=providers) + self._model.start_sessions( + sess_options=sess_options, providers=providers, dtype=dtype + ) if not self.tracer: self.tracer = NoOpTracer() diff --git a/server/clip_server/model/clip_onnx.py b/server/clip_server/model/clip_onnx.py index 2d61307ad..90a9c0f05 100644 --- a/server/clip_server/model/clip_onnx.py +++ b/server/clip_server/model/clip_onnx.py @@ -240,22 +240,6 @@ def __init__( f'The given model path {model_path} should be a folder containing both ' f'`textual.onnx` and `visual.onnx`.' ) - if dtype == 'fp16': - import onnx - from onnxmltools.utils import float16_converter - - _textual_model_fp16 = ( - float16_converter.convert_float_to_float16_model_path( - self._textual_path - ) - ) - _visual_model_fp16 = ( - float16_converter.convert_float_to_float16_model_path( - self._visual_path - ) - ) - onnx.save_model(_textual_model_fp16, self._textual_path) - onnx.save_model(_visual_model_fp16, self._visual_path) else: raise RuntimeError( 'CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}'.format( @@ -279,33 +263,39 @@ def get_model_name(name: str): def start_sessions( self, + dtype, **kwargs, ): import onnxruntime as ort - def _load_session_from_zip(model_path: str, model_type: str): - """Load a model from a zip file.""" - import zipfile - import tempfile + def _load_session(model_path: str, model_type: str, dtype: str): + if model_path.endswith('.zip') or dtype == 'fp16': + import tempfile - with zipfile.ZipFile( - model_path, 'r' - ) as zip_ref, tempfile.TemporaryDirectory() as tmp_dir: - zip_ref.extractall(tmp_dir) - return ort.InferenceSession(tmp_dir + f'/{model_type}.onnx', **kwargs) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_model_path = tmp_dir + f'/{model_type}.onnx' + if model_path.endswith('.zip'): + import zipfile - if self._visual_path.endswith('.zip'): - self._visual_session = _load_session_from_zip(self._visual_path, 'visual') - else: - self._visual_session = ort.InferenceSession(self._visual_path, **kwargs) - self._visual_session.disable_fallback() + with zipfile.ZipFile(model_path, 'r') as zip_ref: + zip_ref.extractall(tmp_dir) + model_path = tmp_model_path + if dtype == 'fp16': + import onnx + from onnxmltools.utils import float16_converter - if self._textual_path.endswith('.zip'): - self._textual_session = _load_session_from_zip( - self._textual_path, 'textual' - ) - else: - self._textual_session = ort.InferenceSession(self._textual_path, **kwargs) + model_fp16 = ( + float16_converter.convert_float_to_float16_model_path( + model_path + ) + ) + onnx.save_model(model_fp16, tmp_model_path) + return ort.InferenceSession(tmp_model_path, **kwargs) + return ort.InferenceSession(model_path, **kwargs) + + self._visual_session = _load_session(self._visual_path, 'visual', dtype) + self._textual_session = _load_session(self._textual_path, 'textual', dtype) + self._visual_session.disable_fallback() self._textual_session.disable_fallback() def encode_image(self, image_input: Dict): diff --git a/tests/test_model.py b/tests/test_model.py index 053a1d2b0..f7e44177b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -29,3 +29,14 @@ def test_torch_model(name, model_cls): ) def test_onnx_model(name): CLIPOnnxModel(name) + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'name', + ['ViT-H-14::laion2b-s32b-b79k'], +) +def test_large_onnx_model_fp16(name): + from clip_server.executors.clip_onnx import CLIPEncoder + + CLIPEncoder(name, dtype='fp16')