feat: add cn clip model (#888)
* feat: basic structure

* feat: switch from modelscope to cnclip

* feat: add test

* fix: add model name map

* fix: test

* chore: black format

* fix: test

* chore: lint check

* fix: adjust encode output format

* fix: add cn_clip installation into yaml file

* fix: init

* fix: init

* fix: cn clip model text tokenization

* chore: correct the quote
Hippopotamus0308 authored Feb 9, 2023
1 parent 8a576c5 commit dabbe8b
Showing 8 changed files with 98 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/cd.yml
@@ -44,6 +44,7 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install --no-cache-dir "server/[cn_clip]"
- name: Test
id: test
run: |
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -113,6 +113,7 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install --no-cache-dir "server/[cn_clip]"
- name: Test
id: test
run: |
@@ -164,6 +165,7 @@
} || {
echo "flash attention was not installed."
}
pip install --no-cache-dir "server/[cn_clip]"
- name: Test
id: test
run: |
6 changes: 6 additions & 0 deletions server/clip_server/model/clip_model.py
@@ -1,6 +1,7 @@
from clip_server.model.pretrained_models import (
    _OPENCLIP_MODELS,
    _MULTILINGUALCLIP_MODELS,
    _CNCLIP_MODELS,
    _VISUAL_MODEL_IMAGE_SIZE,
)

@@ -34,6 +35,10 @@ def __new__(cls, name: str, **kwargs):
            from clip_server.model.mclip_model import MultilingualCLIPModel

            instance = super().__new__(MultilingualCLIPModel)
        elif name in _CNCLIP_MODELS:
            from clip_server.model.cnclip_model import CNClipModel

            instance = super().__new__(CNClipModel)
        else:
            raise ValueError(
                'CLIP model {} not found; below is a list of all available models:\n{}'.format(
@@ -43,6 +48,7 @@ def __new__(cls, name: str, **kwargs):
                            '\t- {}\n'.format(i)
                            for i in list(_OPENCLIP_MODELS.keys())
                            + list(_MULTILINGUALCLIP_MODELS.keys())
                            + list(_CNCLIP_MODELS.keys())
                        ]
                    ),
                )
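A quick illustration of the dispatch above (an editorial sketch, not part of the diff; it assumes the cn_clip extra is installed and that model weights can be downloaded): passing a CN-CLIP name to the CLIPModel factory routes construction to CNClipModel.

from clip_server.model.clip_model import CLIPModel
from clip_server.model.cnclip_model import CNClipModel

# __new__ dispatches on the model name; CN-CLIP weights are fetched on first use
model = CLIPModel('CN-CLIP/ViT-B-16')
assert isinstance(model, CNClipModel)
print(model.image_size)  # expected 224 for ViT-B-16, looked up in _VISUAL_MODEL_IMAGE_SIZE
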
51 changes: 51 additions & 0 deletions server/clip_server/model/cnclip_model.py
@@ -0,0 +1,51 @@
# Originally from https://github.com/OFA-Sys/Chinese-CLIP. MIT License.

import torch

from clip_server.model.clip_model import CLIPModel
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from cn_clip.clip import load_from_name

_CNCLIP_MODEL_MAPS = {
    'CN-CLIP/ViT-B-16': 'ViT-B-16',
    'CN-CLIP/ViT-L-14': 'ViT-L-14',
    'CN-CLIP/ViT-L-14-336': 'ViT-L-14-336',
    'CN-CLIP/ViT-H-14': 'ViT-H-14',
    'CN-CLIP/RN50': 'RN50',
}


class CNClipModel(CLIPModel):
    def __init__(
        self,
        name: str,
        device: str = 'cpu',
        jit: bool = False,
        dtype: str = None,
        **kwargs
    ):
        super().__init__(name, **kwargs)
        self._name = _CNCLIP_MODEL_MAPS[name]

        self._model, self._preprocess = load_from_name(
            _CNCLIP_MODEL_MAPS[name], device=device
        )
        self._model.eval()

    @staticmethod
    def get_model_name(name: str):
        return _CNCLIP_MODEL_MAPS[name]

    def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
        return self._model.encode_text(input_ids).detach()

    def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
        return self._model.encode_image(pixel_values).detach()

    @property
    def model_name(self):
        return self._name  # _name already holds the short CN-CLIP model name

    @property
    def image_size(self):
        return _VISUAL_MODEL_IMAGE_SIZE.get(self._name, None)
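
For context, a minimal usage sketch of the class above (assumptions: cn_clip is installed, weights are downloadable, cat.jpg is a placeholder image path, and the 512-dimensional output applies to ViT-B-16):

from PIL import Image
from cn_clip.clip import tokenize
from clip_server.model.cnclip_model import CNClipModel

model = CNClipModel('CN-CLIP/ViT-B-16', device='cpu')  # downloads weights on first call

input_ids = tokenize(['一只戴着墨镜的猫'], context_length=52)  # LongTensor of shape [1, 52]
text_feat = model.encode_text(input_ids)  # expected shape [1, 512] for ViT-B-16

# the preprocess transform returned by load_from_name resizes and normalizes a PIL image
pixel_values = model._preprocess(Image.open('cat.jpg')).unsqueeze(0)
image_feat = model.encode_image(pixel_values)  # expected shape [1, 512]
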
8 changes: 8 additions & 0 deletions server/clip_server/model/pretrained_models.py
@@ -101,6 +101,14 @@
    'M-CLIP/LABSE-Vit-L-14': (),
}

_CNCLIP_MODELS = {
    'CN-CLIP/ViT-B-16': (),
    'CN-CLIP/ViT-L-14': (),
    'CN-CLIP/ViT-L-14-336': (),
    'CN-CLIP/ViT-H-14': (),
    'CN-CLIP/RN50': (),
}

_VISUAL_MODEL_IMAGE_SIZE = {
    'RN50': 224,
    'RN101': 224,
30 changes: 27 additions & 3 deletions server/clip_server/model/tokenization.py
@@ -1,6 +1,9 @@
import torch
from typing import List, Union
from clip_server.model.pretrained_models import _MULTILINGUALCLIP_MODELS
from clip_server.model.pretrained_models import (
    _MULTILINGUALCLIP_MODELS,
    _CNCLIP_MODELS,
)


class Tokenizer:
@@ -10,6 +13,10 @@ def __init__(self, name: str, **kwargs):
            import transformers

            self._tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        elif name in _CNCLIP_MODELS:
            import cn_clip.clip as cnclip

            self._tokenizer = cnclip
        else:
            from clip_server.model.simple_tokenizer import SimpleTokenizer

@@ -23,13 +30,19 @@ def __call__(
    ):
        """
        :param texts: An input string or a list of input strings to tokenize
        :param context_length: The context length to use; all CLIP models use 77 as the context length.
        :param context_length: The context length to use; English CLIP models use a context length of 77.
            Chinese CLIP models use a context length of 52; inputs longer than 50 characters are truncated and the remainder is discarded.
        :param truncate: Whether to truncate the text in case its encoding is longer than the context length.
        :return: A dict of tokenized representations of the input strings and their corresponding attention masks with both
            shape = [batch size, context_length]
        """
        return self._tokenize(texts, context_length=context_length, truncate=truncate)
        if self._name in _CNCLIP_MODELS:
            return self._tokenize(texts, context_length=52)
        else:
            return self._tokenize(
                texts, context_length=context_length, truncate=truncate
            )

    def _tokenize(
        self,
@@ -52,6 +65,17 @@ def _tokenize(
                'input_ids': result['input_ids'],
                'attention_mask': result['attention_mask'],
            }
        elif self._name in _CNCLIP_MODELS:
            result = self._tokenizer.tokenize(
                texts=texts,
                context_length=52,  # all CN-CLIP baseline models use a context length of 52
            )
            attn_mask = result.clone()
            attn_mask[result != 0] = 1
            return {
                'input_ids': result,
                'attention_mask': attn_mask,
            }
        else:
            sot_token = self._tokenizer.encoder['<|startoftext|>']
            eot_token = self._tokenizer.encoder['<|endoftext|>']
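As a rough illustration of the CN-CLIP branch above (a sketch; it assumes cn_clip is installed): token id 0 is the padding id, so the attention mask is simply 1 wherever a token id is non-zero.

import cn_clip.clip as cnclip

ids = cnclip.tokenize(texts=['今天天气真好'], context_length=52)  # LongTensor of shape [1, 52]
attn_mask = ids.clone()
attn_mask[ids != 0] = 1  # non-padding positions become 1, padding stays 0
batch = {'input_ids': ids, 'attention_mask': attn_mask}
print(batch['input_ids'].shape, int(batch['attention_mask'].sum()))
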
1 change: 1 addition & 0 deletions server/setup.py
@@ -60,6 +60,7 @@
        'transformers': ['transformers>=4.16.2'],
        'search': ['annlite>=0.3.10'],
        'flash-attn': ['flash-attn'],
        'cn_clip': ['cn_clip'],
    },
    classifiers=[
        'Development Status :: 5 - Production/Stable',
2 changes: 2 additions & 0 deletions tests/test_model.py
@@ -3,6 +3,7 @@
from clip_server.model.clip_onnx import CLIPOnnxModel
from clip_server.model.openclip_model import OpenCLIPModel
from clip_server.model.mclip_model import MultilingualCLIPModel
from clip_server.model.cnclip_model import CNClipModel


@pytest.mark.parametrize(
@@ -12,6 +13,7 @@
        ('RN50::openai', OpenCLIPModel),
        ('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel),
        ('M-CLIP/LABSE-Vit-L-14', MultilingualCLIPModel),
        ('CN-CLIP/ViT-B-16', CNClipModel),
    ],
)
def test_torch_model(name, model_cls):
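The body of test_torch_model is collapsed above; a parametrized check of this shape could look roughly like the sketch below (an assumption, not the repository's actual test body; running it downloads model weights).

import pytest
from clip_server.model.clip_model import CLIPModel
from clip_server.model.cnclip_model import CNClipModel


@pytest.mark.parametrize(
    'name, model_cls',
    [('CN-CLIP/ViT-B-16', CNClipModel)],
)
def test_torch_model(name, model_cls):
    model = CLIPModel(name)  # weights are fetched on first instantiation
    assert isinstance(model, model_cls)
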
