Commit f043b4d

feat: update openclip loader (#782)
* feat: update openclip loader

to support an independent download process and adapt model precision to the device, resolving the VRAM issue (a usage sketch follows the changed-file summary below)

* fix: changes for comments

* fix: error

* fix: error

* fix: use openai loader

* fix: address comments

* fix: openclip compatible

Co-authored-by: numb3r3 <[email protected]>
shan-mx and numb3r3 authored Jul 26, 2022
1 parent 5877207 commit f043b4d
Showing 4 changed files with 39 additions and 16 deletions.
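
A minimal usage sketch of the reworked loader, based on the diff below (the laion2b_e16 tag is illustrative; any 'model::pretrained' pair known to get_model_url_md5 should behave the same way):

    from clip_server.model.openclip_model import OpenCLIPModel

    # A bare name defaults to the 'openai' pretrained weights.
    model = OpenCLIPModel('ViT-B-32', device='cpu')

    # 'model_name::pretrained' selects an open_clip checkpoint; clip_server now
    # downloads it itself (download_model) rather than delegating to open_clip,
    # and converts weights to fp16 on CUDA devices to reduce VRAM usage.
    model = OpenCLIPModel('ViT-B-32::laion2b_e16', device='cuda')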
2 changes: 1 addition & 1 deletion server/clip_server/model/mclip_model.py
@@ -73,5 +73,5 @@ def encode_text(
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )
 
-    def encode_image(self, pixel_values: torch.Tensor, **kwargs):
+    def encode_image(self, pixel_values: torch.Tensor):
         return self._model.encode_image(pixel_values)
50 changes: 37 additions & 13 deletions server/clip_server/model/openclip_model.py
@@ -6,11 +6,17 @@
 # Ludwig Schmidt
 
 from typing import TYPE_CHECKING
+from copy import deepcopy
+import torch
 
 from clip_server.model.clip_model import CLIPModel
 from clip_server.model.pretrained_models import get_model_url_md5, download_model
-import open_clip
-from open_clip.openai import load_openai_model
+
+from open_clip.model import (
+    CLIP,
+    convert_weights_to_fp16,
+)
+from open_clip.factory import _MODEL_CONFIGS, load_state_dict, load_openai_model
 
 if TYPE_CHECKING:
     import torch
@@ -20,28 +26,46 @@ class OpenCLIPModel(CLIPModel):
     def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
         super().__init__(name, **kwargs)
 
+        if '::' in name:
+            model_name, pretrained = name.split('::')
+        else:
+            # default pretrained model is from openai
+            model_name = name
+            pretrained = 'openai'
+
+        self._model_name = model_name
+
         model_url, md5sum = get_model_url_md5(name)
-        if model_url:
-            model_path = download_model(model_url, md5sum=md5sum)
+        model_path = download_model(model_url, md5sum=md5sum)
+        if pretrained.lower() == 'openai':
             self._model = load_openai_model(model_path, device=device, jit=jit)
-            self._model_name = name.split('::')[0]
         else:
-            model_name, pretrained = name.split('::')
-            self._model = open_clip.create_model(
-                model_name, pretrained=pretrained, device=device, jit=jit
-            )
-            self._model_name = model_name
+            if model_name in _MODEL_CONFIGS:
+                model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
+            else:
+                raise RuntimeError(f'Model config for {model_name} not found.')
+
+            self._model = CLIP(**model_cfg)
+
+            state_dict = load_state_dict(model_path)
+            self._model.load_state_dict(state_dict, strict=True)
+
+            if str(device) == 'cuda':
+                convert_weights_to_fp16(self._model)
+            if jit:
+                self._model = torch.jit.script(self._model)
+
+            self._model.to(device=torch.device(device))
+            self._model.eval()
 
     @property
     def model_name(self):
         if self._model_name == 'ViT-L/14@336px':
             return 'ViT-L-14-336'
         elif self._model_name.endswith('-quickgelu'):
             return self._model_name[:-10]
         return self._model_name.replace('/', '-')
 
     def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
         return self._model.encode_text(input_ids)
 
-    def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
+    def encode_image(self, pixel_values: 'torch.Tensor'):
         return self._model.encode_image(pixel_values)
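
The VRAM fix is the convert_weights_to_fp16 branch above: weights are halved only when the target device is CUDA, while CPU inference keeps fp32. A standalone sketch of the same pattern for a generic torch module (note that open_clip's helper is more selective than a blanket .half(); it leaves e.g. LayerNorm parameters in fp32):

    import torch

    def adapt_precision(model: torch.nn.Module, device: str) -> torch.nn.Module:
        # fp16 halves parameter memory; CPU kernels generally expect fp32.
        if str(device) == 'cuda':
            model = model.half()  # blunt stand-in for convert_weights_to_fp16
        return model.to(torch.device(device)).eval()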
1 change: 0 additions & 1 deletion server/clip_server/model/pretrained_models.py
@@ -82,7 +82,6 @@
     'ViT-B-32': 224,
     'ViT-B-16': 224,
     'ViT-B-16-plus-240': 240,
-    'ViT-B-16-plus-240': 240,
     'ViT-L-14': 224,
     'ViT-L-14-336': 336,
     'Vit-B-16Plus': 240,
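
The deleted line was a duplicate dict key. Python keeps only the last occurrence at construction time, so the entry was dead weight and removing it changes nothing at runtime:

    sizes = {
        'ViT-B-16-plus-240': 240,
        'ViT-B-16-plus-240': 240,  # silently collapses into the key above
    }
    print(len(sizes))  # 1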
2 changes: 1 addition & 1 deletion server/setup.py
@@ -49,7 +49,7 @@
         'torchvision',
         'jina>=3.6.0',
         'prometheus-client',
-        'open_clip_torch',
+        'open_clip_torch>=1.3.0',
     ],
     extras_require={
         'onnx': [
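
The version floor is raised because the loader now imports from open_clip internals (open_clip.factory and open_clip.model). A hypothetical runtime guard along the same lines, not part of this commit:

    from importlib.metadata import version
    from packaging.version import Version

    # hypothetical check mirroring the setup.py constraint
    assert Version(version('open_clip_torch')) >= Version('1.3.0'), (
        'clip_server requires open_clip_torch>=1.3.0'
    )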
