From dabbe8bc3ef633e4460e1be3f1c06792fe08f00c Mon Sep 17 00:00:00 2001 From: Yang Ruiyi <50010436+Hippopotamus0308@users.noreply.github.com> Date: Thu, 9 Feb 2023 18:48:08 +1100 Subject: [PATCH] feat: add cn clip model (#888) * feat: basic structure * feat: switch from modelscope to cnclip * feat: add test * fix: add model name map * fix: test * chore: black format * fix: test * chore: lint check * fix: adjust encode output format * fix: add cn_clip installation into yaml file * fix: init * fix: init * fix: cn clip model text tokenization * chore: correct the quote --- .github/workflows/cd.yml | 1 + .github/workflows/ci.yml | 2 + server/clip_server/model/clip_model.py | 6 +++ server/clip_server/model/cnclip_model.py | 51 +++++++++++++++++++ server/clip_server/model/pretrained_models.py | 8 +++ server/clip_server/model/tokenization.py | 30 +++++++++-- server/setup.py | 1 + tests/test_model.py | 2 + 8 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 server/clip_server/model/cnclip_model.py diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 0a7e13059..cebda5f82 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -44,6 +44,7 @@ jobs: pip install --no-cache-dir "server/[onnx]" pip install --no-cache-dir "server/[transformers]" pip install --no-cache-dir "server/[search]" + pip install --no-cache-dir "server/[cn_clip]" - name: Test id: test run: | diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 924207527..38d773dc0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -113,6 +113,7 @@ jobs: pip install --no-cache-dir "server/[onnx]" pip install --no-cache-dir "server/[transformers]" pip install --no-cache-dir "server/[search]" + pip install --no-cache-dir "server/[cn_clip]" - name: Test id: test run: | @@ -164,6 +165,7 @@ jobs: } || { echo "flash attention was not installed." } + pip install --no-cache-dir "server/[cn_clip]" - name: Test id: test run: | diff --git a/server/clip_server/model/clip_model.py b/server/clip_server/model/clip_model.py index 7164cfea8..fc40ee75b 100644 --- a/server/clip_server/model/clip_model.py +++ b/server/clip_server/model/clip_model.py @@ -1,6 +1,7 @@ from clip_server.model.pretrained_models import ( _OPENCLIP_MODELS, _MULTILINGUALCLIP_MODELS, + _CNCLIP_MODELS, _VISUAL_MODEL_IMAGE_SIZE, ) @@ -34,6 +35,10 @@ def __new__(cls, name: str, **kwargs): from clip_server.model.mclip_model import MultilingualCLIPModel instance = super().__new__(MultilingualCLIPModel) + elif name in _CNCLIP_MODELS: + from clip_server.model.cnclip_model import CNClipModel + + instance = super().__new__(CNClipModel) else: raise ValueError( 'CLIP model {} not found; below is a list of all available models:\n{}'.format( @@ -43,6 +48,7 @@ def __new__(cls, name: str, **kwargs): '\t- {}\n'.format(i) for i in list(_OPENCLIP_MODELS.keys()) + list(_MULTILINGUALCLIP_MODELS.keys()) + + list(_CNCLIP_MODELS.keys()) ] ), ) diff --git a/server/clip_server/model/cnclip_model.py b/server/clip_server/model/cnclip_model.py new file mode 100644 index 000000000..a8761bae5 --- /dev/null +++ b/server/clip_server/model/cnclip_model.py @@ -0,0 +1,51 @@ +# Originally from https://github.com/OFA-Sys/Chinese-CLIP. MIT License. 
+
+import torch
+
+from clip_server.model.clip_model import CLIPModel
+from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
+from cn_clip.clip import load_from_name
+
+_CNCLIP_MODEL_MAPS = {
+    'CN-CLIP/ViT-B-16': 'ViT-B-16',
+    'CN-CLIP/ViT-L-14': 'ViT-L-14',
+    'CN-CLIP/ViT-L-14-336': 'ViT-L-14-336',
+    'CN-CLIP/ViT-H-14': 'ViT-H-14',
+    'CN-CLIP/RN50': 'RN50',
+}
+
+
+class CNClipModel(CLIPModel):
+    def __init__(
+        self,
+        name: str,
+        device: str = 'cpu',
+        jit: bool = False,
+        dtype: str = None,
+        **kwargs
+    ):
+        super().__init__(name, **kwargs)
+        self._name = _CNCLIP_MODEL_MAPS[name]
+
+        self._model, self._preprocess = load_from_name(
+            _CNCLIP_MODEL_MAPS[name], device=device
+        )
+        self._model.eval()
+
+    @staticmethod
+    def get_model_name(name: str):
+        return _CNCLIP_MODEL_MAPS[name]
+
+    def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
+        return self._model.encode_text(input_ids).detach()
+
+    def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
+        return self._model.encode_image(pixel_values).detach()
+
+    @property
+    def model_name(self):
+        return self._name  # already mapped to the cn_clip model name
+
+    @property
+    def image_size(self):
+        return _VISUAL_MODEL_IMAGE_SIZE.get(self._name, None)
diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py
index fb11b13de..3494bee4b 100644
--- a/server/clip_server/model/pretrained_models.py
+++ b/server/clip_server/model/pretrained_models.py
@@ -101,6 +101,14 @@
     'M-CLIP/LABSE-Vit-L-14': (),
 }
 
+_CNCLIP_MODELS = {
+    'CN-CLIP/ViT-B-16': (),
+    'CN-CLIP/ViT-L-14': (),
+    'CN-CLIP/ViT-L-14-336': (),
+    'CN-CLIP/ViT-H-14': (),
+    'CN-CLIP/RN50': (),
+}
+
 _VISUAL_MODEL_IMAGE_SIZE = {
     'RN50': 224,
     'RN101': 224,
diff --git a/server/clip_server/model/tokenization.py b/server/clip_server/model/tokenization.py
index af1dcf7be..605d571f0 100644
--- a/server/clip_server/model/tokenization.py
+++ b/server/clip_server/model/tokenization.py
@@ -1,6 +1,9 @@
 import torch
 from typing import List, Union
-from clip_server.model.pretrained_models import _MULTILINGUALCLIP_MODELS
+from clip_server.model.pretrained_models import (
+    _MULTILINGUALCLIP_MODELS,
+    _CNCLIP_MODELS,
+)
 
 
 class Tokenizer:
@@ -10,6 +13,10 @@ def __init__(self, name: str, **kwargs):
             import transformers
 
             self._tokenizer = transformers.AutoTokenizer.from_pretrained(name)
+        elif name in _CNCLIP_MODELS:
+            import cn_clip.clip as cnclip
+
+            self._tokenizer = cnclip
         else:
             from clip_server.model.simple_tokenizer import SimpleTokenizer
 
@@ -23,13 +30,19 @@ def __call__(
     ):
         """
         :param texts: An input string or a list of input strings to tokenize
-        :param context_length: The context length to use; all CLIP models use 77 as the context length.
+        :param context_length: The context length to use; all English CLIP models use 77 as the context length.
+            Chinese CLIP models use a fixed context length of 52; inputs longer than 50 characters are truncated and the remainder is dropped.
         :param truncate: Whether to truncate the text in case its encoding is longer than the context length.
         :return: A dict of tokenized representations of the input strings and their corresponding attention masks with both shape = [batch size, context_length]
         """
-        return self._tokenize(texts, context_length=context_length, truncate=truncate)
+        if self._name in _CNCLIP_MODELS:
+            return self._tokenize(texts, context_length=52)
+        else:
+            return self._tokenize(
+                texts, context_length=context_length, truncate=truncate
+            )
 
     def _tokenize(
         self,
@@ -52,6 +65,17 @@ def _tokenize(
                 'input_ids': result['input_ids'],
                 'attention_mask': result['attention_mask'],
             }
+        elif self._name in _CNCLIP_MODELS:
+            result = self._tokenizer.tokenize(
+                texts=texts,
+                context_length=52,  # all CN-CLIP models use a context length of 52
+            )
+            attn_mask = result.clone()
+            attn_mask[result != 0] = 1
+            return {
+                'input_ids': result,
+                'attention_mask': attn_mask,
+            }
         else:
             sot_token = self._tokenizer.encoder['<|startoftext|>']
             eot_token = self._tokenizer.encoder['<|endoftext|>']
diff --git a/server/setup.py b/server/setup.py
index b75b014ac..7d850fabf 100644
--- a/server/setup.py
+++ b/server/setup.py
@@ -60,6 +60,7 @@
         'transformers': ['transformers>=4.16.2'],
         'search': ['annlite>=0.3.10'],
         'flash-attn': ['flash-attn'],
+        'cn_clip': ['cn_clip'],
     },
     classifiers=[
         'Development Status :: 5 - Production/Stable',
diff --git a/tests/test_model.py b/tests/test_model.py
index f7e44177b..75b870b2d 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -3,6 +3,7 @@
 from clip_server.model.clip_onnx import CLIPOnnxModel
 from clip_server.model.openclip_model import OpenCLIPModel
 from clip_server.model.mclip_model import MultilingualCLIPModel
+from clip_server.model.cnclip_model import CNClipModel
 
 
 @pytest.mark.parametrize(
@@ -12,6 +13,7 @@
         ('RN50::openai', OpenCLIPModel),
         ('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel),
         ('M-CLIP/LABSE-Vit-L-14', MultilingualCLIPModel),
+        ('CN-CLIP/ViT-B-16', CNClipModel),
     ],
)
 def test_torch_model(name, model_cls):
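
Note for reviewers: the snippet below is a minimal sketch (not part of the patch) of how the new CN-CLIP path can be exercised end to end. It assumes the `server/[cn_clip]` extra is installed, that the `CLIPModel` factory forwards keyword arguments to `CNClipModel.__init__` the same way it does for the other backends, and that `cn_clip` downloads the checkpoint on first use; the embedding size in the final comment is indicative only.

```python
import torch

from clip_server.model.clip_model import CLIPModel
from clip_server.model.tokenization import Tokenizer

# The factory dispatches to CNClipModel because the name is listed in _CNCLIP_MODELS.
model = CLIPModel('CN-CLIP/ViT-B-16', device='cpu')
tokenizer = Tokenizer('CN-CLIP/ViT-B-16')

# CN-CLIP tokenization always uses a context length of 52, regardless of the
# context_length argument passed by the caller.
inputs = tokenizer(['一只可爱的猫', '在公园里奔跑的狗'])
assert inputs['input_ids'].shape[1] == 52
assert inputs['attention_mask'].shape == inputs['input_ids'].shape

with torch.no_grad():
    text_features = model.encode_text(inputs['input_ids'])

print(text_features.shape)  # e.g. torch.Size([2, 512]) for ViT-B-16
```

The attention mask here comes back from the tokenizer as non-zero positions of the padded token ids, mirroring what `Tokenizer._tokenize` returns for CN-CLIP models in this patch.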