fix: use custom openclip visual model load in mclip
ZiniuYu committed Sep 23, 2022
1 parent f5de308 commit ae046d1
Showing 1 changed file with 9 additions and 16 deletions.
server/clip_server/model/mclip_model.py (25 changes: 9 additions & 16 deletions)
@@ -2,15 +2,15 @@
 
 import transformers
 import torch
-import open_clip
 
 from clip_server.model.clip_model import CLIPModel
+from clip_server.model.openclip_model import OpenCLIPModel
 
 _CLIP_MODEL_MAPS = {
-    'M-CLIP/XLM-Roberta-Large-Vit-B-32': ('ViT-B-32', 'openai'),
-    'M-CLIP/XLM-Roberta-Large-Vit-L-14': ('ViT-L-14', 'openai'),
-    'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': ('ViT-B-16-plus-240', 'laion400m_e31'),
-    'M-CLIP/LABSE-Vit-L-14': ('ViT-L-14', 'openai'),
+    'M-CLIP/XLM-Roberta-Large-Vit-B-32': 'ViT-B-32::openai',
+    'M-CLIP/XLM-Roberta-Large-Vit-L-14': 'ViT-L-14::openai',
+    'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': 'ViT-B-16-plus-240::laion400m_e31',
+    'M-CLIP/LABSE-Vit-L-14': 'ViT-L-14::openai',
 }
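The mapping values above move from an (architecture, pretrained-tag) tuple to the single '<arch>::<pretrained>' spec string that is now passed straight to OpenCLIPModel. A minimal sketch of the equivalence, for illustration only (not part of the commit):

# Illustration only: the mapping value changed from an (architecture,
# pretrained-tag) tuple to one '<arch>::<pretrained>' spec string.
old_value = ('ViT-B-32', 'openai')
new_value = 'ViT-B-32::openai'

arch, pretrained = new_value.split('::')
assert (arch, pretrained) == old_value  # same information, one spec string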


@@ -56,18 +56,11 @@ def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
         self._mclip_model = MultilingualCLIP.from_pretrained(name)
         self._mclip_model.to(device=device)
         self._mclip_model.eval()
+        self._model = OpenCLIPModel(_CLIP_MODEL_MAPS[name], device=device, jit=jit)
 
-        clip_name, clip_pretrained = _CLIP_MODEL_MAPS[name]
-        self._model = open_clip.create_model(
-            clip_name, pretrained=clip_pretrained, device=device, jit=jit
-        )
-        self._model.eval()
-
-        self._clip_name = clip_name
-
-    @property
-    def model_name(self):
-        return self._clip_name
+    @staticmethod
+    def get_model_name(name: str):
+        return _CLIP_MODEL_MAPS[name].split('::')[0]
 
     def encode_text(
         self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', **kwargs
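The constructor now delegates loading of the visual tower to clip_server's OpenCLIPModel wrapper instead of calling open_clip.create_model directly, and the model_name property is replaced by a static get_model_name helper. A hypothetical usage sketch (the wrapper class name MultilingualCLIPModel is an assumption; it is not visible in this excerpt):

# Hypothetical usage sketch; MultilingualCLIPModel is an assumed class name,
# not shown in this diff excerpt.
from clip_server.model.mclip_model import MultilingualCLIPModel

# The new static helper resolves the OpenCLIP architecture from the spec string:
print(MultilingualCLIPModel.get_model_name('M-CLIP/XLM-Roberta-Large-Vit-B-32'))
# -> 'ViT-B-32'

# Constructing the wrapper loads the visual tower via OpenCLIPModel, driven by
# the 'ViT-B-32::openai' spec (downloads weights on first use):
model = MultilingualCLIPModel('M-CLIP/XLM-Roberta-Large-Vit-B-32', device='cpu')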
