From fa62d8e93baf2579b2934cc0ed8daca12c144d7d Mon Sep 17 00:00:00 2001 From: Alex Shan <36291011+shan-mx@users.noreply.github.com> Date: Thu, 21 Jul 2022 17:30:50 +0800 Subject: [PATCH] feat: support openclip&mclip models + refactor model loader (#774) * fix: draft commits * feat: add openclip+mclip support * feat: add openclip+mclip support * fix: update openclip model list * fix: import error * fix: import error * fix: import error * fix: remove executor clip_oc * fix: recovered procs in helper for older version onnx&trt executors * fix: recovered procs in helper for older version onnx&trt executors * fix: add openclip requirement * fix: recover helper.py * fix: add openclip requirement * fix: add openclip requirement * fix: recover helper.py * fix: refactor preprocessor (#776) * fix: refactor preprocessor * fix: error * fix: missing functions * fix: revert commit * fix: tests * fix: simple tokenizer * fix: clip model mixin * feat: support customized download * feat: support customized download * feat: support customized download * fix: init kwargs * fix: minor revision * fix: errors * f * fix: error * fix: update open ai model loading * fix: clean codes * fix: clean codes * fix: open clip base model name * fix: minor revision * fix: clean unused codes * fix: mclip image size * fix: mclip tokenizer error * fix: set padding * fix: set truncation * fix: update tokenizer api * fix: add unittest * fix: update license * fix: unittest Co-authored-by: numb3r3 Co-authored-by: felix-wang <35718120+numb3r3@users.noreply.github.com> --- LICENSE | 3 +- server/clip_server/executors/clip_onnx.py | 9 +- server/clip_server/executors/clip_tensorrt.py | 11 +- server/clip_server/executors/clip_torch.py | 22 +- server/clip_server/executors/helper.py | 16 +- server/clip_server/model/clip.py | 192 +----- server/clip_server/model/clip_model.py | 34 ++ server/clip_server/model/clip_onnx.py | 7 +- server/clip_server/model/mclip_model.py | 77 +++ server/clip_server/model/model.py | 554 ------------------ server/clip_server/model/openclip_model.py | 47 ++ server/clip_server/model/pretrained_models.py | 180 ++++++ server/clip_server/model/tokenization.py | 80 +++ server/setup.py | 1 + tests/test_model.py | 17 + tests/test_server.py | 33 +- tests/test_tokenization.py | 17 + 17 files changed, 505 insertions(+), 795 deletions(-) create mode 100644 server/clip_server/model/clip_model.py create mode 100644 server/clip_server/model/mclip_model.py delete mode 100644 server/clip_server/model/model.py create mode 100644 server/clip_server/model/openclip_model.py create mode 100644 server/clip_server/model/pretrained_models.py create mode 100644 server/clip_server/model/tokenization.py create mode 100644 tests/test_model.py create mode 100644 tests/test_tokenization.py diff --git a/LICENSE b/LICENSE index af4dec1d2..cf4866847 100644 --- a/LICENSE +++ b/LICENSE @@ -1,8 +1,7 @@ Copyright 2020-2022 Jina AI Limited. All rights reserved. 
-The following three files are licensed under MIT License via https://github.com/openai/CLIP Copyright (c) 2021 OpenAI +The following two files are licensed under MIT License via https://github.com/openai/CLIP Copyright (c) 2021 OpenAI server/clip_server/model/clip.py - server/clip_server/model/model.py server/clip_server/model/simple_tokenizer.py diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index 661f55908..ffacd30cc 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -12,6 +12,7 @@ ) from clip_server.model import clip from clip_server.model.clip_onnx import CLIPOnnxModel +from clip_server.model.tokenization import Tokenizer from jina import Executor, requests, DocumentArray @@ -31,10 +32,12 @@ def __init__( self._minibatch_size = minibatch_size self._traversal_paths = traversal_paths - self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name]) self._pool = ThreadPool(processes=num_worker_preprocess) self._model = CLIPOnnxModel(name, model_path) + self._tokenizer = Tokenizer(name) + + self._image_transform = clip._transform_ndarray(clip.MODEL_SIZE[name]) import torch @@ -84,7 +87,7 @@ def _preproc_images(self, docs: 'DocumentArray'): documentation='images preprocess time in seconds', ): return preproc_image( - docs, preprocess_fn=self._preprocess_tensor, return_np=True + docs, preprocess_fn=self._image_transform, return_np=True ) def _preproc_texts(self, docs: 'DocumentArray'): @@ -92,7 +95,7 @@ def _preproc_texts(self, docs: 'DocumentArray'): name='preprocess_texts_seconds', documentation='texts preprocess time in seconds', ): - return preproc_text(docs, return_np=True) + return preproc_text(docs, tokenizer=self._tokenizer, return_np=True) @requests(on='/rank') async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): diff --git a/server/clip_server/executors/clip_tensorrt.py b/server/clip_server/executors/clip_tensorrt.py index 5dc9af251..6dab79183 100644 --- a/server/clip_server/executors/clip_tensorrt.py +++ b/server/clip_server/executors/clip_tensorrt.py @@ -9,6 +9,7 @@ set_rank, ) from clip_server.model import clip +from clip_server.model.tokenization import Tokenizer from clip_server.model.clip_trt import CLIPTensorRTModel from jina import Executor, requests, DocumentArray @@ -25,7 +26,6 @@ def __init__( ): super().__init__(**kwargs) - self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name]) self._pool = ThreadPool(processes=num_worker_preprocess) self._minibatch_size = minibatch_size @@ -48,6 +48,9 @@ def __init__( self._model.start_engines() + self._tokenizer = Tokenizer(name) + self._image_transform = clip._transform_ndarray(clip.MODEL_SIZE[name]) + def _preproc_images(self, docs: 'DocumentArray'): with self.monitor( name='preprocess_images_seconds', @@ -55,7 +58,7 @@ def _preproc_images(self, docs: 'DocumentArray'): ): return preproc_image( docs, - preprocess_fn=self._preprocess_tensor, + preprocess_fn=self._image_transform, device=self._device, return_np=False, ) @@ -65,7 +68,9 @@ def _preproc_texts(self, docs: 'DocumentArray'): name='preprocess_texts_seconds', documentation='texts preprocess time in seconds', ): - return preproc_text(docs, device=self._device, return_np=False) + return preproc_text( + docs, tokenizer=self._tokenizer, device=self._device, return_np=False + ) @requests(on='/rank') async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): diff --git a/server/clip_server/executors/clip_torch.py 
b/server/clip_server/executors/clip_torch.py index 59173307e..e2780d598 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -12,13 +12,15 @@ set_rank, ) from clip_server.model import clip +from clip_server.model.clip_model import CLIPModel +from clip_server.model.tokenization import Tokenizer from jina import Executor, requests, DocumentArray class CLIPEncoder(Executor): def __init__( self, - name: str = 'ViT-B/32', + name: str = 'ViT-B-32-quickgelu::openai', device: Optional[str] = None, jit: bool = False, num_worker_preprocess: int = 4, @@ -53,12 +55,12 @@ def __init__( # For more details, please see https://pytorch.org/docs/stable/generated/torch.set_num_threads.html torch.set_num_threads(max(num_threads, 1)) torch.set_num_interop_threads(1) + self._pool = ThreadPool(processes=num_worker_preprocess) - self._model, self._preprocess_tensor = clip.load( - name, device=self._device, jit=jit - ) + self._model = CLIPModel(name, device=self._device, jit=jit, **kwargs) + self._tokenizer = Tokenizer(name) - self._pool = ThreadPool(processes=num_worker_preprocess) + self._image_transform = clip._transform_ndarray(self._model.image_size) def _preproc_images(self, docs: 'DocumentArray'): with self.monitor( @@ -67,7 +69,7 @@ def _preproc_images(self, docs: 'DocumentArray'): ): return preproc_image( docs, - preprocess_fn=self._preprocess_tensor, + preprocess_fn=self._image_transform, device=self._device, return_np=False, ) @@ -77,7 +79,9 @@ def _preproc_texts(self, docs: 'DocumentArray'): name='preprocess_texts_seconds', documentation='texts preprocess time in seconds', ): - return preproc_text(docs, device=self._device, return_np=False) + return preproc_text( + docs, tokenizer=self._tokenizer, device=self._device, return_np=False + ) @requests(on='/rank') async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): @@ -108,7 +112,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): documentation='images encode time in seconds', ): minibatch.embeddings = ( - self._model.encode_image(batch_data['pixel_values']) + self._model.encode_image(**batch_data) .cpu() .numpy() .astype(np.float32) @@ -126,7 +130,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs): documentation='texts encode time in seconds', ): minibatch.embeddings = ( - self._model.encode_text(batch_data['input_ids']) + self._model.encode_text(**batch_data) .cpu() .numpy() .astype(np.float32) diff --git a/server/clip_server/executors/helper.py b/server/clip_server/executors/helper.py index 9ecb7238c..da24597b9 100644 --- a/server/clip_server/executors/helper.py +++ b/server/clip_server/executors/helper.py @@ -4,7 +4,8 @@ from docarray import Document, DocumentArray from docarray.math.distance.numpy import cosine -from clip_server.model import clip + +from clip_server.model.tokenization import Tokenizer def numpy_softmax(x: 'np.ndarray', axis: int = -1) -> 'np.ndarray': @@ -49,10 +50,13 @@ def preproc_image( def preproc_text( - da: 'DocumentArray', device: str = 'cpu', return_np: bool = False + da: 'DocumentArray', + tokenizer: 'Tokenizer', + device: str = 'cpu', + return_np: bool = False, ) -> Tuple['DocumentArray', Dict]: - inputs = clip.tokenize(da.texts) + inputs = tokenizer(da.texts) inputs['input_ids'] = inputs['input_ids'].detach() if return_np: @@ -113,3 +117,9 @@ def set_rank(docs, _logit_scale=np.exp(4.60517)): ) q.matches = final + + +def get_image_size(name: str): + from 
clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE + + return _VISUAL_MODEL_IMAGE_SIZE[name] diff --git a/server/clip_server/model/clip.py b/server/clip_server/model/clip.py index 9e2fae77e..eae973ab5 100644 --- a/server/clip_server/model/clip.py +++ b/server/clip_server/model/clip.py @@ -5,16 +5,11 @@ import hashlib import shutil import urllib -import warnings -from typing import Union, List +from typing import List -import torch from PIL import Image from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from clip_server.model.model import build_model -from clip_server.model.simple_tokenizer import SimpleTokenizer as _Tokenizer - try: from torchvision.transforms import InterpolationMode @@ -22,8 +17,6 @@ except ImportError: BICUBIC = Image.BICUBIC -__all__ = ['available_models', 'load', 'tokenize'] -_tokenizer = _Tokenizer() _S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch/' _MODELS = { @@ -195,186 +188,3 @@ def _transform_ndarray(n_px): def available_models() -> List[str]: '''Returns the names of available CLIP models''' return list(_MODELS.keys()) - - -def load( - name: str, - device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu', - jit: bool = False, - download_root: str = None, -): - """Load a CLIP model - - Parameters - ---------- - name : str - A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict - - device : Union[str, torch.device] - The device to put the loaded model - - jit : bool - Whether to load the optimized JIT model or more hackable non-JIT model (default). - - download_root: str - path to download the model files; by default, it uses '~/.cache/clip/' - - Returns - ------- - model : torch.nn.Module - The CLIP model - - preprocess : Callable[[PIL.Image], torch.Tensor] - A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input - """ - if name in _MODELS: - model_name, model_md5 = _MODELS[name] - model_path = _download( - url=_S3_BUCKET + model_name, - target_folder=download_root or os.path.expanduser('~/.cache/clip'), - md5sum=model_md5, - with_resume=True, - ) - elif os.path.isfile(name): - model_path = name - else: - raise RuntimeError( - f'Model {name} not found; available models = {available_models()}' - ) - - try: - # loading JIT archive - model = torch.jit.load(model_path, map_location=device if jit else 'cpu').eval() - state_dict = None - except RuntimeError: - # loading saved state dict - if jit: - warnings.warn( - f'File {model_path} is not a JIT archive. 
Loading as a state dict instead' - ) - jit = False - state_dict = torch.load(model_path, map_location='cpu') - - if not jit: - model = build_model(state_dict or model.state_dict()).to(device) - if str(device) == 'cpu': - model.float() - return ( - model, - _transform_ndarray(model.visual.input_resolution), - ) - - # patch the device names - device_holder = torch.jit.trace( - lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] - ) - device_node = [ - n - for n in device_holder.graph.findAllNodes('prim::Constant') - if 'Device' in repr(n) - ][-1] - - def patch_device(module): - try: - graphs = [module.graph] if hasattr(module, 'graph') else [] - except RuntimeError: - graphs = [] - - if hasattr(module, 'forward1'): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes('prim::Constant'): - if 'value' in node.attributeNames() and str(node['value']).startswith( - 'cuda' - ): - node.copyAttributes(device_node) - - model.apply(patch_device) - patch_device(model.encode_image) - patch_device(model.encode_text) - - # patch dtype to float32 on CPU - if str(device) == 'cpu': - float_holder = torch.jit.trace( - lambda: torch.ones([]).float(), example_inputs=[] - ) - float_input = list(float_holder.graph.findNode('aten::to').inputs())[1] - float_node = float_input.node() - - def patch_float(module): - try: - graphs = [module.graph] if hasattr(module, 'graph') else [] - except RuntimeError: - graphs = [] - - if hasattr(module, 'forward1'): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes('aten::to'): - inputs = list(node.inputs()) - for i in [ - 1, - 2, - ]: # dtype can be the second or third argument to aten::to() - if inputs[i].node()['value'] == 5: - inputs[i].node().copyAttributes(float_node) - - model.apply(patch_float) - patch_float(model.encode_image) - patch_float(model.encode_text) - - model.float() - - return ( - model, - _transform_ndarray(model.input_resolution.item()), - ) - - -def tokenize( - texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True -) -> dict: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - - context_length : int - The context length to use; all CLIP models use 77 as the context length - - truncate: bool - Whether to truncate the text in case its encoding is longer than the context length - - Returns - ------- - A dict of tokenized representations of the input strings and their corresponding attention masks with both - shape = [batch size, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder['<|startoftext|>'] - eot_token = _tokenizer.encoder['<|endoftext|>'] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - - input_ids = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - attention_mask = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - if truncate: - tokens = tokens[:context_length] - tokens[-1] = eot_token - else: - raise RuntimeError( - f'Input {texts[i]} is too long for context length {context_length}' - ) - input_ids[i, : len(tokens)] = torch.tensor(tokens) - attention_mask[i, : len(tokens)] = 1 - - return {'input_ids': input_ids, 'attention_mask': attention_mask} diff --git 
a/server/clip_server/model/clip_model.py b/server/clip_server/model/clip_model.py new file mode 100644 index 000000000..0bce755d3 --- /dev/null +++ b/server/clip_server/model/clip_model.py @@ -0,0 +1,34 @@ +from clip_server.model.pretrained_models import ( + _OPENCLIP_MODELS, + _MULTILINGUALCLIP_MODELS, + _VISUAL_MODEL_IMAGE_SIZE, +) + + +class CLIPModel: + def __new__(cls, name: str, **kwargs): + if cls is CLIPModel: + if name in _OPENCLIP_MODELS: + from clip_server.model.openclip_model import OpenCLIPModel + + instance = super().__new__(OpenCLIPModel) + elif name in _MULTILINGUALCLIP_MODELS: + from clip_server.model.mclip_model import MultilingualCLIPModel + + instance = super().__new__(MultilingualCLIPModel) + else: + raise ValueError(f'The CLIP model name=`{name}` is not supported.') + else: + instance = super().__new__(cls) + return instance + + def __init__(self, name: str, **kwargs): + self._name = name + + @property + def model_name(self): + return self._name + + @property + def image_size(self): + return _VISUAL_MODEL_IMAGE_SIZE.get(self.model_name, None) diff --git a/server/clip_server/model/clip_onnx.py b/server/clip_server/model/clip_onnx.py index 1bbb8f57b..d85f51cf4 100644 --- a/server/clip_server/model/clip_onnx.py +++ b/server/clip_server/model/clip_onnx.py @@ -1,6 +1,7 @@ import os -from clip_server.model.clip import _download, available_models +from clip_server.model.clip import available_models +from clip_server.model.pretrained_models import download_model _S3_BUCKET = ( 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/' # Deprecated @@ -54,14 +55,14 @@ def __init__(self, name: str = None, model_path: str = None): f'~/.cache/clip/{name.replace("/", "-")}' ) textual_model_name, textual_model_md5 = _MODELS[name][0] - self._textual_path = _download( + self._textual_path = download_model( url=_S3_BUCKET_V2 + textual_model_name, target_folder=cache_dir, md5sum=textual_model_md5, with_resume=True, ) visual_model_name, visual_model_md5 = _MODELS[name][1] - self._visual_path = _download( + self._visual_path = download_model( url=_S3_BUCKET_V2 + visual_model_name, target_folder=cache_dir, md5sum=visual_model_md5, diff --git a/server/clip_server/model/mclip_model.py b/server/clip_server/model/mclip_model.py new file mode 100644 index 000000000..6e59653c3 --- /dev/null +++ b/server/clip_server/model/mclip_model.py @@ -0,0 +1,77 @@ +# Originally from https://github.com/FreddeFrallan/Multilingual-CLIP. 
MIT License, Copyright (c) 2022 Multilingual-CLIP + +import transformers +import torch +import open_clip + +from clip_server.model.clip_model import CLIPModel +from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE + +corresponding_clip_models = { + 'M-CLIP/XLM-Roberta-Large-Vit-B-32': ('ViT-B-32', 'openai'), + 'M-CLIP/XLM-Roberta-Large-Vit-L-14': ('ViT-L-14', 'openai'), + 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': ('ViT-B-16-plus-240', 'laion400m_e31'), + 'M-CLIP/LABSE-Vit-L-14': ('ViT-L-14', 'openai'), +} + + +class MCLIPConfig(transformers.PretrainedConfig): + model_type = "M-CLIP" + + def __init__( + self, + modelBase: str = 'xlm-roberta-large', + transformerDimSize: int = 1024, + imageDimSize: int = 768, + **kwargs + ): + self.transformerDimensions = transformerDimSize + self.numDims = imageDimSize + self.modelBase = modelBase + super().__init__(**kwargs) + + +class MultilingualCLIP(transformers.PreTrainedModel): + config_class = MCLIPConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.transformer = transformers.AutoModel.from_pretrained(config.modelBase) + self.LinearTransformation = torch.nn.Linear( + in_features=config.transformerDimensions, out_features=config.numDims + ) + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs): + embs = self.transformer( + input_ids=input_ids, attention_mask=attention_mask, **kwargs + )[0] + embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum( + dim=1 + )[:, None] + return self.LinearTransformation(embs) + + +class MultilingualCLIPModel(CLIPModel): + def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): + super().__init__(name, **kwargs) + self._mclip_model = MultilingualCLIP.from_pretrained(name) + + clip_name, clip_pretrained = corresponding_clip_models[name] + self._model = open_clip.create_model( + clip_name, pretrained=clip_pretrained, device=device, jit=jit + ) + self._clip_name = clip_name + + @property + def image_size(self): + return _VISUAL_MODEL_IMAGE_SIZE[self._clip_name] + + def encode_text( + self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', **kwargs + ): + return self._mclip_model( + input_ids=input_ids, attention_mask=attention_mask, **kwargs + ) + + def encode_image(self, pixel_values: torch.Tensor, **kwargs): + return self._model.encode_image(pixel_values) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py deleted file mode 100644 index c0b1a2296..000000000 --- a/server/clip_server/model/model.py +++ /dev/null @@ -1,554 +0,0 @@ -# Originally from https://github.com/openai/CLIP. MIT License, Copyright (c) 2021 OpenAI - -from collections import OrderedDict -from typing import Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1.
an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - - self.relu = nn.ReLU(inplace=True) - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential( - OrderedDict( - [ - ('-1', nn.AvgPool2d(stride)), - ( - '0', - nn.Conv2d( - inplanes, - planes * self.expansion, - 1, - stride=1, - bias=False, - ), - ), - ('1', nn.BatchNorm2d(planes * self.expansion)), - ] - ) - ) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__( - self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None - ): - super().__init__() - self.positional_embedding = nn.Parameter( - torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 - ) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( - 2, 0, 1 - ) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat( - [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] - ), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 'stem' convolutions as opposed to 1, with an average pool instead of a max pool. 
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): - super().__init__() - self.output_dim = output_dim - self.input_resolution = input_resolution - - # the 3-layer stem - self.conv1 = nn.Conv2d( - 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False - ) - self.bn1 = nn.BatchNorm2d(width // 2) - self.conv2 = nn.Conv2d( - width // 2, width // 2, kernel_size=3, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm2d(width // 2) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.avgpool = nn.AvgPool2d(2) - self.relu = nn.ReLU(inplace=True) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d( - input_resolution // 32, embed_dim, heads, output_dim - ) - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def forward(self, x): - def stem(x): - for conv, bn in [ - (self.conv1, self.bn1), - (self.conv2, self.bn2), - (self.conv3, self.bn3), - ]: - x = self.relu(bn(conv(x))) - x = self.avgpool(x) - return x - - x = x.type(self.conv1.weight.dtype) - x = stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - - return x - - -class LayerNorm(nn.LayerNorm): - '''Subclass torch's LayerNorm to handle fp16.''' - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -class QuickGELU(nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential( - OrderedDict( - [ - ('c_fc', nn.Linear(d_model, d_model * 4)), - ('gelu', QuickGELU()), - ('c_proj', nn.Linear(d_model * 4, d_model)), - ] - ) - ) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = ( - self.attn_mask.to(dtype=x.dtype, device=x.device) - if self.attn_mask is not None - else None - ) - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] - - def forward(self, x: torch.Tensor): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__( - self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None - ): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.Sequential( - *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)] - ) - - def forward(self, x: torch.Tensor): - return self.resblocks(x) - - 
-class VisionTransformer(nn.Module): - def __init__( - self, - input_resolution: int, - patch_size: int, - width: int, - layers: int, - heads: int, - output_dim: int, - ): - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False, - ) - - scale = width**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter( - scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) - ) - self.ln_pre = LayerNorm(width) - - self.transformer = Transformer(width, layers, heads) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [ - self.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - - return x - - -class CLIP(nn.Module): - def __init__( - self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int, - ): - super().__init__() - - self.context_length = context_length - - if isinstance(vision_layers, (tuple, list)): - vision_heads = vision_width * 32 // 64 - self.visual = ModifiedResNet( - layers=vision_layers, - output_dim=embed_dim, - heads=vision_heads, - input_resolution=image_resolution, - width=vision_width, - ) - else: - vision_heads = vision_width // 64 - self.visual = VisionTransformer( - input_resolution=image_resolution, - patch_size=vision_patch_size, - width=vision_width, - layers=vision_layers, - heads=vision_heads, - output_dim=embed_dim, - ) - - self.transformer = Transformer( - width=transformer_width, - layers=transformer_layers, - heads=transformer_heads, - attn_mask=self.build_attention_mask(), - ) - - self.vocab_size = vocab_size - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter( - torch.empty(self.context_length, transformer_width) - ) - self.ln_final = LayerNorm(transformer_width) - - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - self.initialize_parameters() - - def initialize_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - if isinstance(self.visual, ModifiedResNet): - if self.visual.attnpool is not None: - std = self.visual.attnpool.c_proj.in_features**-0.5 - nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) - - for 
resnet_block in [ - self.visual.layer1, - self.visual.layer2, - self.visual.layer3, - self.visual.layer4, - ]: - for name, param in resnet_block.named_parameters(): - if name.endswith('bn3.weight'): - nn.init.zeros_(param) - - proj_std = (self.transformer.width**-0.5) * ( - (2 * self.transformer.layers) ** -0.5 - ) - attn_std = self.transformer.width**-0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float('-inf')) - mask.triu_(1) # zero out the lower diagonal - return mask - - @property - def dtype(self): - return self.visual.conv1.weight.dtype - - def encode_image(self, image): - return self.visual(image.type(self.dtype)) - - def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.type(self.dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - def forward(self, image, text): - image_features = self.encode_image(image) - text_features = self.encode_text(text) - - # normalized features - image_features = image_features / image_features.norm(dim=-1, keepdim=True) - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - - # cosine similarity as logits - logit_scale = self.logit_scale.exp() - logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logits_per_image.t() - - # shape = [global_batch_size, global_batch_size] - return logits_per_image, logits_per_text - - -def convert_weights(model: nn.Module): - '''Convert applicable model parameters to fp16''' - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - if isinstance(l, nn.MultiheadAttention): - for attr in [ - *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], - 'in_proj_bias', - 'bias_k', - 'bias_v', - ]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.half() - - for name in ['text_projection', 'proj']: - if hasattr(l, name): - attr = getattr(l, name) - if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - - -def build_model(state_dict: dict): - vit = 'visual.proj' in state_dict - - if vit: - vision_width = state_dict['visual.conv1.weight'].shape[0] - vision_layers = len( - [ - k - for k in state_dict.keys() - if k.startswith('visual.') and k.endswith('.attn.in_proj_weight') - ] - ) - vision_patch_size = state_dict['visual.conv1.weight'].shape[-1] - grid_size = round( - 
(state_dict['visual.positional_embedding'].shape[0] - 1) ** 0.5 - ) - image_resolution = vision_patch_size * grid_size - else: - counts: list = [ - len( - set( - k.split('.')[2] - for k in state_dict - if k.startswith(f'visual.layer{b}') - ) - ) - for b in [1, 2, 3, 4] - ] - vision_layers = tuple(counts) - vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0] - output_width = round( - (state_dict['visual.attnpool.positional_embedding'].shape[0] - 1) ** 0.5 - ) - vision_patch_size = None - assert ( - output_width**2 + 1 - == state_dict['visual.attnpool.positional_embedding'].shape[0] - ) - image_resolution = output_width * 32 - - embed_dim = state_dict['text_projection'].shape[1] - context_length = state_dict['positional_embedding'].shape[0] - vocab_size = state_dict['token_embedding.weight'].shape[0] - transformer_width = state_dict['ln_final.weight'].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len( - set( - k.split('.')[2] - for k in state_dict - if k.startswith(f'transformer.resblocks') - ) - ) - - model = CLIP( - embed_dim, - image_resolution, - vision_layers, - vision_width, - vision_patch_size, - context_length, - vocab_size, - transformer_width, - transformer_heads, - transformer_layers, - ) - - for key in ['input_resolution', 'context_length', 'vocab_size']: - if key in state_dict: - del state_dict[key] - - convert_weights(model) - model.load_state_dict(state_dict) - return model.eval() diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py new file mode 100644 index 000000000..d96ce6509 --- /dev/null +++ b/server/clip_server/model/openclip_model.py @@ -0,0 +1,47 @@ +# Originally from https://github.com/mlfoundations/open_clip. +# +# Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, +# Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, +# John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, +# Ludwig Schmidt + +from typing import TYPE_CHECKING + +from clip_server.model.clip_model import CLIPModel +from clip_server.model.pretrained_models import get_model_url_md5, download_model +import open_clip +from open_clip.openai import load_openai_model + +if TYPE_CHECKING: + import torch + + +class OpenCLIPModel(CLIPModel): + def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): + super().__init__(name, **kwargs) + + model_url, md5sum = get_model_url_md5(name) + if model_url: + model_path = download_model(model_url, md5sum=md5sum) + self._model = load_openai_model(model_path, device=device, jit=jit) + self._model_name = name + else: + model_name, pretrained = name.split('::') + self._model = open_clip.create_model( + model_name, pretrained=pretrained, device=device, jit=jit + ) + self._model_name = model_name + + @property + def model_name(self): + if self._model_name == 'ViT-L/14@336px': + return 'ViT-L-14-336' + elif self._model_name.endswith('-quickgelu'): + return self._model_name[:-10] + return self._model_name.replace('/', '-') + + def encode_text(self, input_ids: 'torch.Tensor', **kwargs): + return self._model.encode_text(input_ids) + + def encode_image(self, pixel_values: 'torch.Tensor', **kwargs): + return self._model.encode_image(pixel_values) diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py new file mode 100644 index 000000000..367578257 --- /dev/null +++ b/server/clip_server/model/pretrained_models.py @@ -0,0 +1,180 @@ +import os +import hashlib +import shutil +import 
urllib + + +_OPENCLIP_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch' +_OPENCLIP_MODELS = { + 'RN50::openai': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), + 'RN50::yfcc15m': (), + 'RN50::cc12m': (), + 'RN50-quickgelu::openai': (), + 'RN50-quickgelu::yfcc15m': (), + 'RN50-quickgelu::cc12m': (), + 'RN101::openai': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), + 'RN101::yfcc15m': (), + 'RN101-quickgelu::openai': (), + 'RN101-quickgelu::yfcc15m': (), + 'RN50x4::openai': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'), + 'RN50x16::openai': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'), + 'RN50x64::openai': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'), + 'ViT-B-32::openai': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'), + 'ViT-B-32::laion2b_e16': (), + 'ViT-B-32::laion400m_e31': (), + 'ViT-B-32::laion400m_e32': (), + 'ViT-B-32-quickgelu::openai': (), + 'ViT-B-32-quickgelu::laion400m_e31': (), + 'ViT-B-32-quickgelu::laion400m_e32': (), + 'ViT-B-16::openai': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'), + 'ViT-B-16::laion400m_e31': (), + 'ViT-B-16::laion400m_e32': (), + 'ViT-B-16-plus-240::laion400m_e31': (), + 'ViT-B-16-plus-240::laion400m_e32': (), + 'ViT-L-14::openai': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'), + 'ViT-L-14-336::openai': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'), + # older version name format + 'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), + 'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), + 'RN50x4': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'), + 'RN50x16': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'), + 'RN50x64': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'), + 'ViT-B/32': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'), + 'ViT-B/16': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'), + 'ViT-L/14': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'), + 'ViT-L/14@336px': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'), +} + +_MULTILINGUALCLIP_MODELS = { + 'M-CLIP/XLM-Roberta-Large-Vit-B-32': (), + 'M-CLIP/XLM-Roberta-Large-Vit-L-14': (), + 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': (), + 'M-CLIP/LABSE-Vit-L-14': (), +} + +_VISUAL_MODEL_IMAGE_SIZE = { + 'RN50': 224, + 'RN101': 224, + 'RN50x4': 288, + 'RN50x16': 384, + 'RN50x64': 448, + 'ViT-B-32': 224, + 'ViT-B-16': 224, + 'ViT-B-16-plus-240': 240, + 'ViT-B-16-plus-240': 240, + 'ViT-L-14': 224, + 'ViT-L-14-336': 336, + 'Vit-B-16Plus': 240, +} + + +def md5file(filename: str): + hash_md5 = hashlib.md5() + with open(filename, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + + return hash_md5.hexdigest() + + +def get_model_url_md5(name: str): + model_pretrained = _OPENCLIP_MODELS[name] + if len(model_pretrained) == 0: # not on s3 + return None, None + else: + return (_OPENCLIP_S3_BUCKET + '/' + model_pretrained[0], model_pretrained[1]) + + +def download_model( + url: str, + target_folder: str = os.path.expanduser("~/.cache/clip"), + md5sum: str = None, + with_resume: bool = True, + max_attempts: int = 3, +) -> str: + os.makedirs(target_folder, exist_ok=True) + filename = os.path.basename(url) + + download_target = os.path.join(target_folder, filename) + + if os.path.exists(download_target): + if not os.path.isfile(download_target): + raise FileExistsError(f'{download_target} exists and is not a regular file') + + actual_md5sum = md5file(download_target) + if (not md5sum) or actual_md5sum == md5sum: + return download_target + + from rich.progress import ( 
+ DownloadColumn, + Progress, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, + ) + + progress = Progress( + " \n", # divide this bar from Flow's bar + TextColumn("[bold blue]{task.fields[filename]}", justify="right"), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + DownloadColumn(), + "•", + TransferSpeedColumn(), + "•", + TimeRemainingColumn(), + ) + + with progress: + task = progress.add_task('download', filename=filename, start=False) + + for _ in range(max_attempts): + tmp_file_path = download_target + '.part' + resume_byte_pos = ( + os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0 + ) + + try: + # resolve the 403 error by passing a valid user-agent + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + total_bytes = int( + urllib.request.urlopen(req).info().get('Content-Length', -1) + ) + mode = 'ab' if (with_resume and resume_byte_pos) else 'wb' + + with open(tmp_file_path, mode) as output: + progress.update(task, total=total_bytes) + progress.start_task(task) + + if resume_byte_pos and with_resume: + progress.update(task, advance=resume_byte_pos) + req.headers['Range'] = f'bytes={resume_byte_pos}-' + + with urllib.request.urlopen(req) as source: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + progress.update(task, advance=len(buffer)) + + actual_md5 = md5file(tmp_file_path) + if (md5sum and actual_md5 == md5sum) or (not md5sum): + shutil.move(tmp_file_path, download_target) + return download_target + else: + os.remove(tmp_file_path) + raise RuntimeError( + f'MD5 mismatch: expected {md5sum}, got {actual_md5}' + ) + + except Exception as ex: + progress.console.print( + f'Failed to download {url} with {ex!r} at the {_}th attempt' + ) + progress.reset(task) + + raise RuntimeError( + f'Failed to download {url} within retry limit {max_attempts}' + ) diff --git a/server/clip_server/model/tokenization.py b/server/clip_server/model/tokenization.py new file mode 100644 index 000000000..af1dcf7be --- /dev/null +++ b/server/clip_server/model/tokenization.py @@ -0,0 +1,80 @@ +import torch +from typing import List, Union +from clip_server.model.pretrained_models import _MULTILINGUALCLIP_MODELS + + +class Tokenizer: + def __init__(self, name: str, **kwargs): + self._name = name + if name in _MULTILINGUALCLIP_MODELS: + import transformers + + self._tokenizer = transformers.AutoTokenizer.from_pretrained(name) + else: + from clip_server.model.simple_tokenizer import SimpleTokenizer + + self._tokenizer = SimpleTokenizer() + + def __call__( + self, + texts: Union[str, List[str]], + context_length: int = 77, + truncate: bool = True, + ): + """ + :param texts: An input string or a list of input strings to tokenize + :param context_length: The context length to use; all CLIP models use 77 as the context length. + :param truncate: Whether to truncate the text in case its encoding is longer than the context length. 
+ + :return: A dict of tokenized representations of the input strings and their corresponding attention masks with both + shape = [batch size, context_length] + """ + return self._tokenize(texts, context_length=context_length, truncate=truncate) + + def _tokenize( + self, + texts: Union[str, List[str]], + context_length: int = 77, + truncate: bool = True, + ) -> dict: + if isinstance(texts, str): + texts = [texts] + if self._name in _MULTILINGUALCLIP_MODELS: + result = self._tokenizer( + texts, + max_length=context_length, + return_attention_mask=True, + return_tensors='pt', + padding=True, + truncation=True, + ) + return { + 'input_ids': result['input_ids'], + 'attention_mask': result['attention_mask'], + } + else: + sot_token = self._tokenizer.encoder['<|startoftext|>'] + eot_token = self._tokenizer.encoder['<|endoftext|>'] + all_tokens = [ + [sot_token] + self._tokenizer.encode(text) + [eot_token] + for text in texts + ] + + input_ids = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + attention_mask = torch.zeros( + len(all_tokens), context_length, dtype=torch.long + ) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError( + f'Input {texts[i]} is too long for context length {context_length}' + ) + input_ids[i, : len(tokens)] = torch.tensor(tokens) + attention_mask[i, : len(tokens)] = 1 + + return {'input_ids': input_ids, 'attention_mask': attention_mask} diff --git a/server/setup.py b/server/setup.py index 9105bb6fb..3dbade3bc 100644 --- a/server/setup.py +++ b/server/setup.py @@ -49,6 +49,7 @@ 'torchvision', 'jina>=3.6.0', 'prometheus-client', + 'open_clip_torch', ], extras_require={ 'onnx': [ diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 000000000..4dc57c635 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,17 @@ +import pytest +from clip_server.model.clip_model import CLIPModel +from clip_server.model.openclip_model import OpenCLIPModel +from clip_server.model.mclip_model import MultilingualCLIPModel + + +@pytest.mark.parametrize( + 'name, model_cls', + [ + ('ViT-L/14@336px', OpenCLIPModel), + ('RN101-quickgelu::openai', OpenCLIPModel), + ('M-CLIP/XLM-Roberta-Large-Vit-B-32', MultilingualCLIPModel), + ], +) +def test_model_name(name, model_cls): + model = CLIPModel(name) + assert model.__class__ == model_cls diff --git a/tests/test_server.py b/tests/test_server.py index a9b476149..523a4cbbb 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -70,37 +70,16 @@ def test_server_download_not_regular_file(tmpdir): ) -def test_make_onnx_flow_custom_path_wrong_name(port_generator): +def test_make_onnx_flow_wrong_name_path(port_generator): from clip_server.executors.clip_onnx import CLIPEncoder - f = Flow(port=port_generator()).add( - name='onnx', - uses=CLIPEncoder, - uses_with={ - 'name': 'ABC', - 'model_path': os.path.expanduser('~/.cache/clip/ViT-B-32'), - }, - ) - with pytest.raises(Exception) as info: - with f: - f.post('/', Document(text='Hello world')) - - -@pytest.mark.parametrize('path', ['ABC', os.path.expanduser('~/.cache/')]) -def test_make_onnx_flow_custom_path_wrong_path(port_generator, path): - from clip_server.executors.clip_onnx import CLIPEncoder + with pytest.raises(Exception): + encoder = CLIPEncoder( + 'ABC', model_path=os.path.expanduser('~/.cache/clip/ViT-B-32') + ) - f = Flow(port=port_generator()).add( - name='onnx', - uses=CLIPEncoder, - uses_with={ - 'name': 'ViT-B/32', - 
'model_path': path, - }, - ) with pytest.raises(Exception) as info: - with f: - f.post('/', Document(text='Hello world')) + encoder = CLIPEncoder('ViT-B/32', model_path='~/.cache/') @pytest.mark.parametrize( diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py new file mode 100644 index 000000000..0fb954ca1 --- /dev/null +++ b/tests/test_tokenization.py @@ -0,0 +1,17 @@ +import pytest +from clip_server.model.tokenization import Tokenizer + + +@pytest.mark.parametrize( + 'name', ['ViT-L/14@336px', 'M-CLIP/XLM-Roberta-Large-Vit-B-32'] +) +def test_tokenizer_name(name): + tokenizer = Tokenizer(name) + + result = tokenizer('hello world') + assert result['input_ids'].shape == result['attention_mask'].shape + assert result['input_ids'].shape[0] == 1 + + result = tokenizer(['hello world', 'welcome to the world']) + assert result['input_ids'].shape == result['attention_mask'].shape + assert result['input_ids'].shape[0] == 2
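
Usage sketch (appended note, not part of the diff): a minimal example of how the refactored pieces introduced by this patch — the CLIPModel factory, the Tokenizer, and the encode_* keyword-argument interface — are expected to fit together, mirroring what the updated clip_torch.py executor does. It uses the new default model name 'ViT-B-32-quickgelu::openai' and assumes open_clip can fetch the pretrained weights; the input texts are illustrative only.

import torch

from clip_server.model.clip_model import CLIPModel
from clip_server.model.tokenization import Tokenizer

name = 'ViT-B-32-quickgelu::openai'

# CLIPModel.__new__ dispatches to OpenCLIPModel here (or to MultilingualCLIPModel
# for 'M-CLIP/...' names), so callers never import the concrete subclass directly.
model = CLIPModel(name, device='cpu', jit=False)
tokenizer = Tokenizer(name)  # SimpleTokenizer here; AutoTokenizer for M-CLIP names

inputs = tokenizer(['a photo of a cat', 'a photo of a dog'])
with torch.no_grad():
    # encode_text takes input_ids (and, for M-CLIP, attention_mask) as keyword
    # arguments, which is why the executors now call it with **batch_data.
    embeddings = model.encode_text(**inputs)

print(embeddings.shape)  # [2, embedding_dim]
print(model.image_size)  # 224, looked up via _VISUAL_MODEL_IMAGE_SIZE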