From bd0440223844809c5282f53ef23677b209782833 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Tue, 26 Jul 2022 11:34:45 +0800 Subject: [PATCH] fix: trt runtime --- server/clip_server/model/clip_trt.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/server/clip_server/model/clip_trt.py b/server/clip_server/model/clip_trt.py index 40e24f828..4d53400ef 100644 --- a/server/clip_server/model/clip_trt.py +++ b/server/clip_server/model/clip_trt.py @@ -13,7 +13,10 @@ "Please find installation instruction on " "https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html" ) - +from clip_server.model.pretrained_models import ( + _OPENCLIP_MODELS, + _MULTILINGUALCLIP_MODELS, +) from clip_server.model.clip_model import BaseCLIPModel from clip_server.model.clip_onnx import _MODELS as ONNX_MODELS @@ -98,6 +101,19 @@ def __init__( f'Model {name} not found or not supports Nvidia TensorRT backend; available models = {list(_MODELS.keys())}' ) + @staticmethod + def get_model_name(name: str): + if name in _OPENCLIP_MODELS: + from clip_server.model.openclip_model import OpenCLIPModel + + return OpenCLIPModel.get_model_name(name) + elif name in _MULTILINGUALCLIP_MODELS: + from clip_server.model.mclip_model import MultilingualCLIPModel + + return MultilingualCLIPModel.get_model_name(name) + + return name + def start_engines(self): trt_logger: Logger = trt.Logger(trt.Logger.ERROR) runtime: Runtime = trt.Runtime(trt_logger)