feat: add cn clip model #888

Merged
merged 14 commits on Feb 9, 2023
1 change: 1 addition & 0 deletions .github/workflows/cd.yml
@@ -44,6 +44,7 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install --no-cache-dir "server/[cn_clip]"
- name: Test
id: test
run: |
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -113,6 +113,7 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install --no-cache-dir "server/[cn_clip]"
- name: Test
id: test
run: |
@@ -164,6 +165,7 @@ jobs:
} || {
echo "flash attention was not installed."
}
pip install --no-cache-dir "server/[cn_clip]"
- name: Test
id: test
run: |
6 changes: 6 additions & 0 deletions server/clip_server/model/clip_model.py
@@ -1,6 +1,7 @@
from clip_server.model.pretrained_models import (
_OPENCLIP_MODELS,
_MULTILINGUALCLIP_MODELS,
_CNCLIP_MODELS,
_VISUAL_MODEL_IMAGE_SIZE,
)

@@ -34,6 +35,10 @@ def __new__(cls, name: str, **kwargs):
from clip_server.model.mclip_model import MultilingualCLIPModel

instance = super().__new__(MultilingualCLIPModel)
elif name in _CNCLIP_MODELS:
from clip_server.model.cnclip_model import CNClipModel

instance = super().__new__(CNClipModel)
else:
raise ValueError(
'CLIP model {} not found; below is a list of all available models:\n{}'.format(
@@ -43,6 +48,7 @@ def __new__(cls, name: str, **kwargs):
'\t- {}\n'.format(i)
for i in list(_OPENCLIP_MODELS.keys())
+ list(_MULTILINGUALCLIP_MODELS.keys())
+ list(_CNCLIP_MODELS.keys())
]
),
)
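For context, a minimal sketch (not part of the diff) of what the new branch enables: a CN-CLIP model name is now routed to the CN-CLIP wrapper instead of falling through to the "not found" error. This assumes the cn_clip extra is installed, and instantiating the model will download weights on first use.

# Sketch only; requires the cn_clip extra and downloads weights on first run.
from clip_server.model.clip_model import CLIPModel
from clip_server.model.cnclip_model import CNClipModel

model = CLIPModel('CN-CLIP/ViT-B-16')   # __new__ dispatches on the model name
assert isinstance(model, CNClipModel)   # mirrors the new parameter in tests/test_model.py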
51 changes: 51 additions & 0 deletions server/clip_server/model/cnclip_model.py
@@ -0,0 +1,51 @@
# Originally from https://github.com/OFA-Sys/Chinese-CLIP. MIT License.

import torch

from clip_server.model.clip_model import CLIPModel
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from cn_clip.clip import load_from_name

_CNCLIP_MODEL_MAPS = {
'CN-CLIP/ViT-B-16': 'ViT-B-16',
'CN-CLIP/ViT-L-14': 'ViT-L-14',
'CN-CLIP/ViT-L-14-336': 'ViT-L-14-336',
'CN-CLIP/ViT-H-14': 'ViT-H-14',
'CN-CLIP/RN50': 'RN50',
}


class CNClipModel(CLIPModel):
def __init__(
self,
name: str,
device: str = 'cpu',
jit: bool = False,
dtype: str = None,
**kwargs
):
super().__init__(name, **kwargs)
self._name = _CNCLIP_MODEL_MAPS[name]

self._model, self._preprocess = load_from_name(
_CNCLIP_MODEL_MAPS[name], device=device
)
self._model.eval()

@staticmethod
def get_model_name(name: str):
return _CNCLIP_MODEL_MAPS[name]

def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
return self._model.encode_text(input_ids).detach()

def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
return self._model.encode_image(pixel_values).detach()

@property
def model_name(self):
# self._name already holds the mapped CN-CLIP internal name (see __init__);
# passing it through get_model_name again would raise a KeyError.
return self._name

@property
def image_size(self):
return _VISUAL_MODEL_IMAGE_SIZE.get(self._name, None)
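A short, untested usage sketch of the wrapper on its own. The Chinese prompt, the use of cn_clip's tokenize helper, and the expected embedding width are assumptions for illustration, not something this diff asserts.

# Sketch only; assumes cn_clip is installed and can download the ViT-B-16 weights.
import torch
from cn_clip.clip import tokenize
from clip_server.model.cnclip_model import CNClipModel

model = CNClipModel('CN-CLIP/ViT-B-16', device='cpu')
input_ids = tokenize(['一只戴着墨镜的猫'], context_length=52)  # padded token ids
with torch.no_grad():
    text_emb = model.encode_text(input_ids)
# image_size -> 224; text_emb.shape is expected to be (1, 512) for ViT-B-16 (assumption)
print(model.image_size, text_emb.shape)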
8 changes: 8 additions & 0 deletions server/clip_server/model/pretrained_models.py
@@ -101,6 +101,14 @@
'M-CLIP/LABSE-Vit-L-14': (),
}

_CNCLIP_MODELS = {
'CN-CLIP/ViT-B-16': (),
'CN-CLIP/ViT-L-14': (),
'CN-CLIP/ViT-L-14-336': (),
'CN-CLIP/ViT-H-14': (),
'CN-CLIP/RN50': (),
}

_VISUAL_MODEL_IMAGE_SIZE = {
'RN50': 224,
'RN101': 224,
30 changes: 27 additions & 3 deletions server/clip_server/model/tokenization.py
@@ -1,6 +1,9 @@
import torch
from typing import List, Union
from clip_server.model.pretrained_models import _MULTILINGUALCLIP_MODELS
from clip_server.model.pretrained_models import (
_MULTILINGUALCLIP_MODELS,
_CNCLIP_MODELS,
)


class Tokenizer:
@@ -10,6 +13,10 @@ def __init__(self, name: str, **kwargs):
import transformers

self._tokenizer = transformers.AutoTokenizer.from_pretrained(name)
elif name in _CNCLIP_MODELS:
import cn_clip.clip as cnclip

self._tokenizer = cnclip
else:
from clip_server.model.simple_tokenizer import SimpleTokenizer

@@ -23,13 +30,19 @@ def __call__(
):
"""
:param texts: An input string or a list of input strings to tokenize
:param context_length: The context length to use; all CLIP models use 77 as the context length.
:param context_length: The context length to use; all English CLIP models use 77 as the context length.
For Chinese CLIP models the context length is 52; if a sentence is longer than 50 characters, it is truncated and the remainder is dropped.
:param truncate: Whether to truncate the text in case its encoding is longer than the context length.

:return: A dict of tokenized representations of the input strings and their corresponding attention masks with both
shape = [batch size, context_length]
"""
return self._tokenize(texts, context_length=context_length, truncate=truncate)
if self._name in _CNCLIP_MODELS:
return self._tokenize(texts, context_length=52)
else:
return self._tokenize(
texts, context_length=context_length, truncate=truncate
)

def _tokenize(
self,
@@ -52,6 +65,17 @@ def _tokenize(
'input_ids': result['input_ids'],
'attention_mask': result['attention_mask'],
}
elif self._name in _CNCLIP_MODELS:
result = self._tokenizer.tokenize(
texts=texts,
context_length=52,  # all CN-CLIP baseline models use a context length of 52
)
attn_mask = result.clone()
attn_mask[result != 0] = 1
return {
"input_ids": result,
"attention_mask": attn_mask,
}
else:
sot_token = self._tokenizer.encoder['<|startoftext|>']
eot_token = self._tokenizer.encoder['<|endoftext|>']
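As an aside for reviewers, a tiny standalone illustration of the attention-mask construction in the CN-CLIP branch above. The token id values are made up; the only assumption, which the diff itself relies on, is that cn_clip pads sequences with id 0.

# Illustration only; token id values are hypothetical.
import torch

input_ids = torch.tensor([[101, 2769, 4263, 102, 0, 0]])  # padded to context length
attention_mask = input_ids.clone()
attention_mask[input_ids != 0] = 1  # mark non-padding positions
# attention_mask -> tensor([[1, 1, 1, 1, 0, 0]])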
1 change: 1 addition & 0 deletions server/setup.py
@@ -60,6 +60,7 @@
'transformers': ['transformers>=4.16.2'],
'search': ['annlite>=0.3.10'],
'flash-attn': ['flash-attn'],
'cn_clip': ['cn_clip'],
},
classifiers=[
'Development Status :: 5 - Production/Stable',
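To try this locally, the new extra can be installed the same way the CI/CD workflows above do; the extra name matches the cn_clip key added to extras_require (run from the repository root):

pip install --no-cache-dir "server/[cn_clip]"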
2 changes: 2 additions & 0 deletions tests/test_model.py
@@ -3,6 +3,7 @@
from clip_server.model.clip_onnx import CLIPOnnxModel
from clip_server.model.openclip_model import OpenCLIPModel
from clip_server.model.mclip_model import MultilingualCLIPModel
from clip_server.model.cnclip_model import CNClipModel


@pytest.mark.parametrize(
@@ -12,6 +13,7 @@
('RN50::openai', OpenCLIPModel),
('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel),
('M-CLIP/LABSE-Vit-L-14', MultilingualCLIPModel),
('CN-CLIP/ViT-B-16', CNClipModel),
],
)
def test_torch_model(name, model_cls):