diff --git a/server/clip_server/model/mclip_model.py b/server/clip_server/model/mclip_model.py index c5b9058d5..149de0482 100644 --- a/server/clip_server/model/mclip_model.py +++ b/server/clip_server/model/mclip_model.py @@ -6,9 +6,9 @@ from clip_server.model.clip_model import CLIPModel -corresponding_clip_models = { +_CLIP_MODEL_MAPS = { 'M-CLIP/XLM-Roberta-Large-Vit-B-32': ('ViT-B-32', 'openai'), - 'M-CLIP/XLM-Roberta-Large-Vi-L-14': ('ViT-L-14', 'openai'), + 'M-CLIP/XLM-Roberta-Large-Vit-L-14': ('ViT-L-14', 'openai'), 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': ('ViT-B-16-plus-240', 'laion400m_e31'), 'M-CLIP/LABSE-Vit-L-14': ('ViT-L-14', 'openai'), } @@ -54,11 +54,15 @@ class MultilingualCLIPModel(CLIPModel): def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): super().__init__(name, **kwargs) self._mclip_model = MultilingualCLIP.from_pretrained(name) + self._mclip_model.to(device=device) + self._mclip_model.eval() - clip_name, clip_pretrained = corresponding_clip_models[name] + clip_name, clip_pretrained = _CLIP_MODEL_MAPS[name] self._model = open_clip.create_model( clip_name, pretrained=clip_pretrained, device=device, jit=jit ) + self._model.eval() + self._clip_name = clip_name @property