Skip to content

Commit

Permalink
fix: trt runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
numb3r3 committed Jul 26, 2022
1 parent ddf86cf commit bd04402
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion server/clip_server/model/clip_trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
"Please find installation instruction on "
"https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html"
)

from clip_server.model.pretrained_models import (
_OPENCLIP_MODELS,
_MULTILINGUALCLIP_MODELS,
)
from clip_server.model.clip_model import BaseCLIPModel
from clip_server.model.clip_onnx import _MODELS as ONNX_MODELS

Expand Down Expand Up @@ -98,6 +101,19 @@ def __init__(
f'Model {name} not found or not supports Nvidia TensorRT backend; available models = {list(_MODELS.keys())}'
)

@staticmethod
def get_model_name(name: str):
if name in _OPENCLIP_MODELS:
from clip_server.model.openclip_model import OpenCLIPModel

return OpenCLIPModel.get_model_name(name)
elif name in _MULTILINGUALCLIP_MODELS:
from clip_server.model.mclip_model import MultilingualCLIPModel

return MultilingualCLIPModel.get_model_name(name)

return name

def start_engines(self):
trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
Expand Down

0 comments on commit bd04402

Please sign in to comment.