From 5524f460ab1091d74e7a479257c239142557f3dc Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Wed, 16 Nov 2022 16:58:15 +0800 Subject: [PATCH 01/33] feat: bump openclip to v2.5.0 (#859) * feat: bump openclip to v2.5.0 * fix: conflicts * fix: default fp32 on cpu and fp16 on gpu * feat: add two new models * fix: remove debug * fix: add roberta models (test) * fix: model name xlm * fix: (wip) --- server/clip_server/model/mclip_model.py | 8 +- server/clip_server/model/model.py | 868 +++++++----------- server/clip_server/model/modified_resnet.py | 212 +++++ server/clip_server/model/openclip_model.py | 6 +- server/clip_server/model/pretrained_models.py | 10 + .../model/pretrained_text_encode.py | 197 ++++ server/clip_server/model/transformer.py | 423 +++++++++ server/clip_server/model/utils.py | 63 ++ server/setup.py | 2 +- 9 files changed, 1231 insertions(+), 558 deletions(-) create mode 100644 server/clip_server/model/modified_resnet.py create mode 100644 server/clip_server/model/pretrained_text_encode.py create mode 100644 server/clip_server/model/transformer.py create mode 100644 server/clip_server/model/utils.py diff --git a/server/clip_server/model/mclip_model.py b/server/clip_server/model/mclip_model.py index c85519501..49e7f6198 100644 --- a/server/clip_server/model/mclip_model.py +++ b/server/clip_server/model/mclip_model.py @@ -2,6 +2,7 @@ import transformers import torch +from transformers.models.roberta.modeling_roberta import RobertaModel from clip_server.model.clip_model import CLIPModel from clip_server.model.openclip_model import OpenCLIPModel @@ -11,6 +12,8 @@ 'M-CLIP/XLM-Roberta-Large-Vit-L-14': 'ViT-L-14::openai', 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': 'ViT-B-16-plus-240::laion400m_e31', 'M-CLIP/LABSE-Vit-L-14': 'ViT-L-14::openai', + 'roberta-base': '', + 'xlm-roberta-base': '', } @@ -53,7 +56,10 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwarg class MultilingualCLIPModel(CLIPModel): def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): super().__init__(name, **kwargs) - self._mclip_model = MultilingualCLIP.from_pretrained(name) + if name in ('roberta-base', 'xlm-roberta-base'): + self._mclip_model = RobertaModel.from_pretrained(name) + else: + self._mclip_model = MultilingualCLIP.from_pretrained(name) self._mclip_model.to(device=device) self._mclip_model.eval() self._model = OpenCLIPModel(_CLIP_MODEL_MAPS[name], device=device, jit=jit) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index e87c56411..6cda54607 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -8,432 +8,31 @@ Ludwig Schmidt """ +import logging +import math import warnings -import collections.abc -from collections import OrderedDict from dataclasses import dataclass -from typing import Tuple, Union, Callable, Optional -from itertools import repeat +from typing import Tuple, Union, Optional from copy import deepcopy import numpy as np import torch import torch.nn.functional as F from torch import nn -from torch.utils.checkpoint import checkpoint from open_clip.timm_model import TimmModel -from open_clip.utils import freeze_batch_norm_2d from open_clip.factory import _MODEL_CONFIGS - -# Use flash attention -try: - from clip_server.model.flash_attention import MultiheadAttention - - FLASH_ATTENTION_AVAILABLE = True -except: - FLASH_ATTENTION_AVAILABLE = False - - -# From PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - 
return tuple(repeat(x, n)) - - return parse - - -to_2tuple = _ntuple(2) - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.act1 = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.act2 = nn.ReLU(inplace=True) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.act3 = nn.ReLU(inplace=True) - - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential( - OrderedDict( - [ - ("-1", nn.AvgPool2d(stride)), - ( - "0", - nn.Conv2d( - inplanes, - planes * self.expansion, - 1, - stride=1, - bias=False, - ), - ), - ("1", nn.BatchNorm2d(planes * self.expansion)), - ] - ) - ) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.act1(self.bn1(self.conv1(x))) - out = self.act2(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.act3(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__( - self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None - ): - super().__init__() - self.positional_embedding = nn.Parameter( - torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 - ) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( - 2, 0, 1 - ) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat( - [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] - ), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, image_size=224, width=64): - super().__init__() - self.output_dim = output_dim - self.image_size = image_size - - # the 3-layer stem - self.conv1 = nn.Conv2d( - 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False - ) - self.bn1 = nn.BatchNorm2d(width // 2) - self.act1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d( - width // 2, width // 2, kernel_size=3, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm2d(width // 2) - self.act2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.act3 = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(2) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) - - self.init_parameters() - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def init_parameters(self): - if self.attnpool is not None: - std = self.attnpool.c_proj.in_features**-0.5 - nn.init.normal_(self.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert ( - unlocked_groups == 0 - ), 'partial locking not currently supported for this model' - for param in self.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - # FIXME support for non-transformer - pass - - def stem(self, x): - x = self.act1(self.bn1(self.conv1(x))) - x = self.act2(self.bn2(self.conv2(x))) - x = self.act3(self.bn3(self.conv3(x))) - x = self.avgpool(x) - return x - - def forward(self, x): - # To handle fp16 inference - x = x.type(self.conv1.weight.dtype) - x = self.stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - - return x - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - x = F.layer_norm( - x.type(torch.float32), - self.normalized_shape, - self.weight, - self.bias, - self.eps, - ) - return x.to(orig_type) - - -class QuickGELU(nn.Module): - # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__( - self, 
- d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - scale_cosine_attn: bool = False, - scale_heads: bool = False, - scale_attn: bool = False, - scale_fc: bool = False, - ): - super().__init__() - head_dim = d_model // n_head - self.flash_attention = head_dim % 8 == 0 and head_dim <= 128 - - self.ln_1 = LayerNorm(d_model) - self.attn = ( - MultiheadAttention(d_model, n_head) - if FLASH_ATTENTION_AVAILABLE - and torch.cuda.is_available() - and self.flash_attention - else nn.MultiheadAttention(d_model, n_head) - ) - self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity() - - self.ln_2 = LayerNorm(d_model) - mlp_width = int(d_model * mlp_ratio) - self.mlp = nn.Sequential( - OrderedDict( - [ - ("c_fc", nn.Linear(d_model, mlp_width)), - ('ln', LayerNorm(mlp_width) if scale_fc else nn.Identity()), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)), - ] - ) - ) - - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - if attn_mask is not None: - attn_mask = attn_mask.to(dtype=x.dtype, device=x.device) - x = x + self.ln_attn(self.attention(self.ln_1(x), attn_mask=attn_mask)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - ): - super().__init__() - self.width = width - self.layers = layers - self.grad_checkpointing = False - - self.resblocks = nn.ModuleList( - [ - ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer) - for _ in range(layers) - ] - ) - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - for r in self.resblocks: - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - return x - - -class VisualTransformer(nn.Module): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - output_dim: int, - act_layer: Callable = nn.GELU, - ): - super().__init__() - self.image_size = to_2tuple(image_size) - self.patch_size = to_2tuple(patch_size) - self.grid_size = ( - self.image_size[0] // self.patch_size[0], - self.image_size[1] // self.patch_size[1], - ) - self.output_dim = output_dim - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False, - ) - - scale = width**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter( - scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width) - ) - self.ln_pre = LayerNorm(width) - - self.transformer = Transformer( - width, layers, heads, mlp_ratio, act_layer=act_layer - ) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert ( - unlocked_groups == 0 - ), 'partial locking not currently supported for this model' - for param in self.parameters(): - param.requires_grad = False - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.grad_checkpointing = enable - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], 
x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [ - self.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - - return x +from .modified_resnet import ModifiedResNet +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + TextTransformer, + VisionTransformer, + Attention, +) +from .utils import to_2tuple +from .pretrained_text_encode import PreTrainedTextEncoder @dataclass @@ -444,6 +43,7 @@ class CLIPVisionCfg: mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value timm_model_name: str = ( None # a valid model name overrides layers, width, patch_size ) @@ -456,6 +56,7 @@ class CLIPVisionCfg: timm_proj: str = ( 'linear' # linear projection for timm model output ('linear', 'mlp', '') ) + timm_proj_bias: bool = False # enable bias final projection @dataclass @@ -465,116 +66,141 @@ class CLIPTextCfg: width: int = 512 heads: int = 8 layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
+ act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + act_layer = ( + nn.GELU + ) # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = ( + LayerNormFp32 + if cast_dtype in (torch.float16, torch.bfloat16) + else LayerNorm + ) + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + # TODO: adapt this + # ls_init_value=vision_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return visual -class CLIP(nn.Module): - def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - ): - super().__init__() - if isinstance(vision_cfg, dict): - vision_cfg = CLIPVisionCfg(**vision_cfg) - if isinstance(text_cfg, dict): - text_cfg = CLIPTextCfg(**text_cfg) - - self.context_length = text_cfg.context_length - # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more - # memory efficient in recent PyTorch releases (>= 1.10). - # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
+def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = PreTrainedTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + ) + else: act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 + if cast_dtype in (torch.float16, torch.bfloat16) + else LayerNorm + ) - if vision_cfg.timm_model_name: - self.visual = TimmModel( - vision_cfg.timm_model_name, - pretrained=vision_cfg.timm_model_pretrained, - pool=vision_cfg.timm_pool, - proj=vision_cfg.timm_proj, - embed_dim=embed_dim, - image_size=vision_cfg.image_size, - ) - act_layer = ( - nn.GELU - ) # so that text transformer doesn't use QuickGELU w/ timm models - elif isinstance(vision_cfg.layers, (tuple, list)): - vision_heads = vision_cfg.width * 32 // vision_cfg.head_width - self.visual = ModifiedResNet( - layers=vision_cfg.layers, - output_dim=embed_dim, - heads=vision_heads, - image_size=vision_cfg.image_size, - width=vision_cfg.width, - ) - else: - vision_heads = vision_cfg.width // vision_cfg.head_width - self.visual = VisualTransformer( - image_size=vision_cfg.image_size, - patch_size=vision_cfg.patch_size, - width=vision_cfg.width, - layers=vision_cfg.layers, - heads=vision_heads, - mlp_ratio=vision_cfg.mlp_ratio, - output_dim=embed_dim, - act_layer=act_layer, - ) - - self.transformer = Transformer( + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, width=text_cfg.width, - layers=text_cfg.layers, heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, act_layer=act_layer, + norm_layer=norm_layer, ) + return text - self.vocab_size = text_cfg.vocab_size - self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) - self.positional_embedding = nn.Parameter( - torch.empty(self.context_length, text_cfg.width) - ) - self.ln_final = LayerNorm(text_cfg.width) - self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim)) +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) - - self.init_parameters() - - def init_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) - - if hasattr(self.visual, 'init_parameters'): - self.visual.init_parameters() - - proj_std = (self.transformer.width**-0.5) * ( - (2 * self.transformer.layers) ** -0.5 - ) - attn_std = self.transformer.width**-0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - 
nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - @property - def dtype(self): - return self.visual.conv1.weight.dtype + self.register_buffer('attn_mask', text.attn_mask, persistent=False) def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 @@ -587,48 +213,82 @@ def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.transformer.grad_checkpointing = enable - def encode_image(self, image): - return self.visual(image.type(self.dtype)) + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() - def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - x = x + self.positional_embedding.type(self.dtype) + x = x + self.positional_embedding.to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=self.attn_mask) x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - # x.shape = [batch_size, n_ctx, transformer.width] + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x + return F.normalize(x, dim=-1) if normalize else x def forward(self, image, text): - if image is None: - return self.encode_text(text) - elif text is None: - return self.encode_image(image) - image_features = self.encode_image(image) - image_features = F.normalize(image_features, dim=-1) + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock( + unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats + ) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) - text_features = self.encode_text(text) - 
text_features = F.normalize(text_features, dim=-1) + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) return image_features, text_features, self.logit_scale.exp() -def convert_weights_to_fp16(model: nn.Module): - """Convert applicable model parameters to fp16""" +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" - def _convert_weights_to_fp16(l): + def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() + l.weight.data = l.weight.data.to(dtype) if l.bias is not None: - l.bias.data = l.bias.data.half() + l.bias.data = l.bias.data.to(dtype) - if isinstance(l, nn.MultiheadAttention): + if isinstance(l, (nn.MultiheadAttention, Attention)): for attr in [ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", @@ -637,15 +297,18 @@ def _convert_weights_to_fp16(l): ]: tensor = getattr(l, attr) if tensor is not None: - tensor.data = tensor.data.half() + tensor.data = tensor.data.to(dtype) for name in ["text_projection", "proj"]: if hasattr(l, name): attr = getattr(l, name) if attr is not None: - attr.data = attr.data.half() + attr.data = attr.data.to(dtype) - model.apply(_convert_weights_to_fp16) + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat def load_state_dict(checkpoint_path: str, map_location='cpu'): @@ -659,7 +322,11 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): return state_dict -def build_model_from_openai_state_dict(state_dict: dict): +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): vit = "visual.proj" in state_dict if vit: @@ -729,7 +396,8 @@ def build_model_from_openai_state_dict(state_dict: dict): embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg, - quick_gelu=True, # OpenAI models were trained with QuickGELU + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: @@ -743,18 +411,24 @@ def build_model_from_openai_state_dict(state_dict: dict): def load_openai_model( model_path: str, + precision: Optional[str] = None, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, ): """Load a CLIP model + Parameters ---------- model_path : str The path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. device : Union[str, torch.device] The device to put the loaded model jit : bool Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
+ cache_dir : Optional[str] + The directory to cache the downloaded model weights Returns ------- model : torch.nn.Module @@ -762,6 +436,8 @@ def load_openai_model( preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() @@ -776,16 +452,23 @@ def load_openai_model( state_dict = torch.load(model_path, map_location="cpu") if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) try: model = build_model_from_openai_state_dict( - state_dict or model.state_dict() - ).to(device) + state_dict or model.state_dict(), cast_dtype=cast_dtype + ) except KeyError: sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} - model = build_model_from_openai_state_dict(sd).to(device) + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) - if str(device) == "cpu": + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith('amp') or precision == 'fp32': model.float() + elif precision == 'bf16': + convert_weights_to_lp(model, dtype=torch.bfloat16) + return model # patch the device names @@ -818,8 +501,8 @@ def patch_device(module): patch_device(model.encode_image) patch_device(model.encode_text) - # patch dtype to float32 on CPU - if str(device) == "cpu": + # patch dtype to float32 (typically for CPU) + if precision == 'fp32': float_holder = torch.jit.trace( lambda: torch.ones([]).float(), example_inputs=[] ) @@ -858,9 +541,11 @@ def patch_float(module): def load_openclip_model( model_name: str, model_path: str, + precision: str = 'fp32', device: torch.device = torch.device('cpu'), jit: bool = False, force_quick_gelu: bool = False, + force_custom_text: bool = False, pretrained_image: bool = False, ): model_name = model_name.replace( @@ -885,17 +570,92 @@ def load_openclip_model( False ), 'pretrained image towers currently only supported for timm models' - model = CLIP(**model_cfg) - model.eval() - - model.load_state_dict(load_state_dict(model_path)) + cast_dtype = get_cast_dtype(precision) + custom_text = ( + model_cfg.pop('custom_text', False) + or force_custom_text + or ('hf_model_name' in model_cfg['text_cfg']) + ) - if str(device).startswith('cuda'): - convert_weights_to_fp16(model) + if custom_text: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + print(model_cfg) + model.eval() + model.load_state_dict(load_state_dict(model_path)) model.to(device=device) + if precision in ("fp16", "bf16"): + convert_weights_to_lp( + model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16 + ) + if jit: model = torch.jit.script(model) return model + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros( + (batch_size, model.context_length), dtype=torch.int, device=device + ) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,), + ), + ) + model.visual.image_size = image_size + return 
model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = ( + 1 # FIXME detect different token configs (ie no class token, or more) + ) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = ( + old_pos_embed[:extra_tokens], + old_pos_embed[extra_tokens:], + ) + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info( + 'Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size + ) + pos_emb_img = pos_emb_img.reshape( + 1, old_grid_size[0], old_grid_size[1], -1 + ).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape( + 1, grid_size[0] * grid_size[1], -1 + )[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/server/clip_server/model/modified_resnet.py b/server/clip_server/model/modified_resnet.py new file mode 100644 index 000000000..acba87c83 --- /dev/null +++ b/server/clip_server/model/modified_resnet.py @@ -0,0 +1,212 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from open_clip.utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + # To handle fp16 inference + x = x.type(self.conv1.weight.dtype) + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py index 9a329f9c5..cc791b36a 100644 --- a/server/clip_server/model/openclip_model.py +++ b/server/clip_server/model/openclip_model.py @@ -24,8 +24,10 @@ def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): self._model_name = model_name - model_url, md5sum = get_model_url_md5(name) - model_path = download_model(model_url, md5sum=md5sum) + # model_url, md5sum = get_model_url_md5(name) + # model_path = 
download_model(model_url, md5sum=md5sum) + model_path = '/home/zonlin/.cache/clip/test.bin' + pretrained = None if pretrained == 'openai': self._model = load_openai_model(model_path, device=device, jit=jit) diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py index bc8f18daa..ea3ecf91e 100644 --- a/server/clip_server/model/pretrained_models.py +++ b/server/clip_server/model/pretrained_models.py @@ -70,6 +70,14 @@ 'ViT-g-14-laion2b-s12b-b42k.bin', '3bf99353f6f1829faac0bb155be4382a', ), + 'xlm-roberta-base-ViT-B-32': ( + 'ViT-B-32-laion2b-s12b-b32k.bin', + '76d4c9d13774cc15fa0e2b1b94a8402c', + ), + 'ViT-B-32::laion5b-s13b-b90k': ( + 'ViT-B-32::laion5b-s13b-b90k', + 'f68abc07ef349720f1f880180803142d', + ), # older version name format 'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), 'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), @@ -87,6 +95,8 @@ 'M-CLIP/XLM-Roberta-Large-Vit-L-14': (), 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': (), 'M-CLIP/LABSE-Vit-L-14': (), + 'roberta-base': (), + 'xlm-roberta-base': (), } _VISUAL_MODEL_IMAGE_SIZE = { diff --git a/server/clip_server/model/pretrained_text_encode.py b/server/clip_server/model/pretrained_text_encode.py new file mode 100644 index 000000000..74e236397 --- /dev/null +++ b/server/clip_server/model/pretrained_text_encode.py @@ -0,0 +1,197 @@ +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, + ) +except ImportError as e: + transformers = None + + class BaseModelOutput: + pass + + class PretrainedConfig: + pass + + +# HF architecture dict: +arch_dict = { + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + }, + "pooler": "mean_pooler", + }, +} + + +# utils +def _camel2snake(s): + return re.sub(r'(? 
TensorType: + attn_mask = (x != self.config.pad_token_id).long() + out = self.transformer(input_ids=x, attention_mask=attn_mask) + pooled_out = self.pooler(out, attn_mask) + + return self.proj(pooled_out) + + def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + if not unlocked_layers: # full freezing + for n, p in self.transformer.named_parameters(): + p.requires_grad = ( + (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + ) + return + + n_layers = ( + len(self.transformer.encoder.layer) - unlocked_layers - 1 + ) # -1 for embeddings + modules = [ + self.transformer.embeddings, + self.transformer.encoder.layer[:n_layers], + ] + for module in modules: + for n, p in module.named_parameters(): + p.requires_grad = ( + (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.gradient_checkpointing_enable() + + def init_parameters(self): + pass diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py new file mode 100644 index 000000000..386010a1c --- /dev/null +++ b/server/clip_server/model/transformer.py @@ -0,0 +1,423 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple + +# Use flash attention +try: + from clip_server.model.flash_attention import MultiheadAttention + + FLASH_ATTENTION_AVAILABLE = True +except: + FLASH_ATTENTION_AVAILABLE = False + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm( + x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps + ) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm( + x.type(torch.float32), + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1.0 / 0.01), + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + 
self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))) + ) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm( + F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2) + ) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + head_dim = d_model // n_head + self.flash_attention = head_dim % 8 == 0 and head_dim <= 128 + + self.ln_1 = norm_layer(d_model) + self.attn = ( + MultiheadAttention(d_model, n_head) + if FLASH_ATTENTION_AVAILABLE + and torch.cuda.is_available() + and self.flash_attention + else nn.MultiheadAttention(d_model, n_head) + ) + self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ln_attn(self.attention(self.ln_1(x), attn_mask=attn_mask)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer + ) + for _ in range(layers) + ] + ) + + def get_cast_dtype(self) -> torch.dtype: 
+ return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class TextTransformer(nn.Module): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, width) + ) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + # TODO: adapt this + # ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers) ** -0.5 + ) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = ( + self.image_size[0] // self.patch_size[0], + 
self.image_size[1] // self.patch_size[1], + ) + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width) + ) + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + # TODO: adapt this + # ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x diff --git a/server/clip_server/model/utils.py b/server/clip_server/model/utils.py new file mode 100644 index 000000000..6afdc4b90 --- /dev/null +++ b/server/clip_server/model/utils.py @@ -0,0 +1,63 @@ +from itertools import repeat +import collections.abc + +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
+ module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance( + module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) + ): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) diff --git a/server/setup.py b/server/setup.py index 78275b0c1..56fceb645 100644 --- a/server/setup.py +++ b/server/setup.py @@ -48,7 +48,7 @@ 'torchvision', 'jina>=3.8.0', 'prometheus-client', - 'open_clip_torch>=1.3.0', + 'open_clip_torch>=2.5.0', ], extras_require={ 'onnx': [ From 06ec06f5c7189a79d42f9f4b0631ea17ef813999 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Wed, 16 Nov 2022 17:20:53 +0800 Subject: [PATCH 02/33] fix: remove roberta model --- server/clip_server/model/mclip_model.py | 8 +------- server/clip_server/model/model.py | 1 - 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/server/clip_server/model/mclip_model.py b/server/clip_server/model/mclip_model.py index 49e7f6198..c85519501 100644 --- a/server/clip_server/model/mclip_model.py +++ b/server/clip_server/model/mclip_model.py @@ -2,7 +2,6 @@ import transformers import torch -from transformers.models.roberta.modeling_roberta import RobertaModel from clip_server.model.clip_model import CLIPModel from clip_server.model.openclip_model import OpenCLIPModel @@ -12,8 +11,6 @@ 'M-CLIP/XLM-Roberta-Large-Vit-L-14': 'ViT-L-14::openai', 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': 'ViT-B-16-plus-240::laion400m_e31', 'M-CLIP/LABSE-Vit-L-14': 'ViT-L-14::openai', - 'roberta-base': '', - 'xlm-roberta-base': '', } @@ -56,10 +53,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwarg class MultilingualCLIPModel(CLIPModel): def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): super().__init__(name, **kwargs) - if name in ('roberta-base', 'xlm-roberta-base'): - self._mclip_model = RobertaModel.from_pretrained(name) - else: - self._mclip_model = MultilingualCLIP.from_pretrained(name) + self._mclip_model = MultilingualCLIP.from_pretrained(name) self._mclip_model.to(device=device) self._mclip_model.eval() self._model = OpenCLIPModel(_CLIP_MODEL_MAPS[name], device=device, jit=jit) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 6cda54607..2ac45ecf3 100644 --- a/server/clip_server/model/model.py 
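The two helpers added in server/clip_server/model/utils.py above can be exercised as in the following minimal sketch (illustrative only; the torchvision ResNet is a hypothetical stand-in for any module containing BatchNorm2d layers, and note that later commits in this series re-export these helpers from open_clip and eventually drop the local module):

    import torch
    from torchvision.models import resnet18
    from clip_server.model.utils import freeze_batch_norm_2d, to_2tuple

    # Replace every BatchNorm2d with FrozenBatchNorm2d so running statistics
    # and affine parameters stay fixed during fine-tuning.
    backbone = freeze_batch_norm_2d(resnet18())
    assert not any(isinstance(m, torch.nn.BatchNorm2d) for m in backbone.modules())

    # to_2tuple normalizes a scalar or an iterable into a 2-tuple, e.g. for image sizes.
    assert to_2tuple(224) == (224, 224)
    assert to_2tuple((224, 336)) == (224, 336)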
+++ b/server/clip_server/model/model.py @@ -581,7 +581,6 @@ def load_openclip_model( model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: model = CLIP(**model_cfg, cast_dtype=cast_dtype) - print(model_cfg) model.eval() model.load_state_dict(load_state_dict(model_path)) From 7fcb81390b77eeb63e6826445eed9d2d9253b1c0 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Wed, 16 Nov 2022 17:29:46 +0800 Subject: [PATCH 03/33] fix: model name --- server/clip_server/model/pretrained_models.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py index ea3ecf91e..01a211010 100644 --- a/server/clip_server/model/pretrained_models.py +++ b/server/clip_server/model/pretrained_models.py @@ -70,12 +70,12 @@ 'ViT-g-14-laion2b-s12b-b42k.bin', '3bf99353f6f1829faac0bb155be4382a', ), - 'xlm-roberta-base-ViT-B-32': ( + 'roberta-ViT-B-32': ( 'ViT-B-32-laion2b-s12b-b32k.bin', '76d4c9d13774cc15fa0e2b1b94a8402c', ), - 'ViT-B-32::laion5b-s13b-b90k': ( - 'ViT-B-32::laion5b-s13b-b90k', + 'xlm-roberta-base-ViT-B-32': ( + 'ViT-B-32::laion5b-s13b-b90k.bin', 'f68abc07ef349720f1f880180803142d', ), # older version name format @@ -95,8 +95,6 @@ 'M-CLIP/XLM-Roberta-Large-Vit-L-14': (), 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': (), 'M-CLIP/LABSE-Vit-L-14': (), - 'roberta-base': (), - 'xlm-roberta-base': (), } _VISUAL_MODEL_IMAGE_SIZE = { From c4beeca7c001d8d388e6df78dbc141a3352aa3d3 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Wed, 16 Nov 2022 17:34:25 +0800 Subject: [PATCH 04/33] fix: add :: to model name --- server/clip_server/model/openclip_model.py | 1 - server/clip_server/model/pretrained_models.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py index cc791b36a..5c411d6de 100644 --- a/server/clip_server/model/openclip_model.py +++ b/server/clip_server/model/openclip_model.py @@ -27,7 +27,6 @@ def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): # model_url, md5sum = get_model_url_md5(name) # model_path = download_model(model_url, md5sum=md5sum) model_path = '/home/zonlin/.cache/clip/test.bin' - pretrained = None if pretrained == 'openai': self._model = load_openai_model(model_path, device=device, jit=jit) diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py index 01a211010..272f9bdaa 100644 --- a/server/clip_server/model/pretrained_models.py +++ b/server/clip_server/model/pretrained_models.py @@ -70,11 +70,11 @@ 'ViT-g-14-laion2b-s12b-b42k.bin', '3bf99353f6f1829faac0bb155be4382a', ), - 'roberta-ViT-B-32': ( + 'roberta-ViT-B-32::laion2b-s12b-b32k': ( 'ViT-B-32-laion2b-s12b-b32k.bin', '76d4c9d13774cc15fa0e2b1b94a8402c', ), - 'xlm-roberta-base-ViT-B-32': ( + 'xlm-roberta-base-ViT-B-32::laion5b-s13b-b90k': ( 'ViT-B-32::laion5b-s13b-b90k.bin', 'f68abc07ef349720f1f880180803142d', ), From 5fbfb571a8b846d6453fe3aa9cb6948fc6f57405 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Wed, 16 Nov 2022 19:09:52 +0800 Subject: [PATCH 05/33] fix: add transformers --- server/clip_server/model/model.py | 2 -- server/clip_server/model/openclip_model.py | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 2ac45ecf3..6177026c3 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -132,8 +132,6 @@ def 
_build_vision_tower( layers=vision_cfg.layers, heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, - # TODO: adapt this - # ls_init_value=vision_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py index 5c411d6de..9a329f9c5 100644 --- a/server/clip_server/model/openclip_model.py +++ b/server/clip_server/model/openclip_model.py @@ -24,9 +24,8 @@ def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): self._model_name = model_name - # model_url, md5sum = get_model_url_md5(name) - # model_path = download_model(model_url, md5sum=md5sum) - model_path = '/home/zonlin/.cache/clip/test.bin' + model_url, md5sum = get_model_url_md5(name) + model_path = download_model(model_url, md5sum=md5sum) if pretrained == 'openai': self._model = load_openai_model(model_path, device=device, jit=jit) From d5dd1ce80a3e96a85575031141841ffe4d8afd7b Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Wed, 16 Nov 2022 19:11:51 +0800 Subject: [PATCH 06/33] fix: remove is_init_value --- server/clip_server/model/model.py | 1 - server/clip_server/model/transformer.py | 6 ------ 2 files changed, 7 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 6177026c3..78c736147 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -170,7 +170,6 @@ def _build_text_tower( width=text_cfg.width, heads=text_cfg.heads, layers=text_cfg.layers, - ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py index 386010a1c..4f3879d67 100644 --- a/server/clip_server/model/transformer.py +++ b/server/clip_server/model/transformer.py @@ -232,7 +232,6 @@ def __init__( width: int = 512, heads: int = 8, layers: int = 12, - ls_init_value: float = None, output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, @@ -251,8 +250,6 @@ def __init__( width=width, layers=layers, heads=heads, - # TODO: adapt this - # ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) @@ -320,7 +317,6 @@ def __init__( layers: int, heads: int, mlp_ratio: float, - ls_init_value: float = None, output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, @@ -352,8 +348,6 @@ def __init__( layers, heads, mlp_ratio, - # TODO: adapt this - # ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) From d937f1372b46fb340b418b5dcccfec5773825339 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Wed, 16 Nov 2022 23:08:29 +0800 Subject: [PATCH 07/33] fix: remove transformers --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 87dae3336..949e4c2d0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -163,6 +163,7 @@ jobs: } pip install -e "client/[test]" pip install -e "server/[tensorrt]" + pip install -e "server/[transformers]" - name: Test id: test run: | From fe2745ba083781065ca1ac1c9e421373c3585e04 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Thu, 17 Nov 2022 13:59:41 +0800 Subject: [PATCH 08/33] fix: not use flash-attn on cpu --- server/clip_server/model/model.py | 14 ++++++++------ server/clip_server/model/transformer.py | 14 +++++++++++++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git 
a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 78c736147..d18a39acf 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -79,6 +79,8 @@ def get_cast_dtype(precision: str): cast_dtype = torch.bfloat16 elif precision == 'fp16': cast_dtype = torch.float16 + elif precision == 'fp32': + cast_dtype = torch.float32 return cast_dtype @@ -135,6 +137,7 @@ def _build_vision_tower( output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, + cast_dtype=cast_dtype, ) return visual @@ -173,6 +176,7 @@ def _build_text_tower( output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, + cast_dtype=cast_dtype, ) return text @@ -408,8 +412,7 @@ def build_model_from_openai_state_dict( def load_openai_model( model_path: str, - precision: Optional[str] = None, - device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu', jit=True, ): """Load a CLIP model @@ -433,8 +436,7 @@ def load_openai_model( preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ - if precision is None: - precision = 'fp32' if device == 'cpu' else 'fp16' + precision = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() @@ -538,13 +540,13 @@ def patch_float(module): def load_openclip_model( model_name: str, model_path: str, - precision: str = 'fp32', - device: torch.device = torch.device('cpu'), + device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_text: bool = False, pretrained_image: bool = False, ): + precision = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' model_name = model_name.replace( '/', '-' ) # for callers using old naming with / in ViT names diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py index 4f3879d67..b7d5a09d4 100644 --- a/server/clip_server/model/transformer.py +++ b/server/clip_server/model/transformer.py @@ -150,6 +150,7 @@ def __init__( norm_layer: Callable = LayerNorm, scale_attn: bool = False, scale_fc: bool = False, + cast_dtype=torch.float32, ): super().__init__() head_dim = d_model // n_head @@ -160,6 +161,7 @@ def __init__( MultiheadAttention(d_model, n_head) if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available() + and cast_dtype in (torch.float16, torch.bfloat16) and self.flash_attention else nn.MultiheadAttention(d_model, n_head) ) @@ -197,6 +199,7 @@ def __init__( mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + cast_dtype=torch.float32, ): super().__init__() self.width = width @@ -206,7 +209,12 @@ def __init__( self.resblocks = nn.ModuleList( [ ResidualAttentionBlock( - width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer + width, + heads, + mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + cast_dtype=cast_dtype, ) for _ in range(layers) ] @@ -235,6 +243,7 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + cast_dtype=torch.float16, ): super().__init__() self.context_length = context_length @@ -252,6 +261,7 @@ def __init__( heads=heads, act_layer=act_layer, norm_layer=norm_layer, + cast_dtype=cast_dtype, ) self.ln_final = norm_layer(width) self.text_projection = 
nn.Parameter(torch.empty(width, output_dim)) @@ -320,6 +330,7 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + cast_dtype=torch.float32, ): super().__init__() self.image_size = to_2tuple(image_size) @@ -350,6 +361,7 @@ def __init__( mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, + cast_dtype=cast_dtype, ) self.ln_post = norm_layer(width) From 7ddd51b4490d291da9d7452d734c958be63f141f Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Thu, 17 Nov 2022 14:06:59 +0800 Subject: [PATCH 09/33] fix: add assert description --- server/clip_server/model/flash_attention.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/server/clip_server/model/flash_attention.py b/server/clip_server/model/flash_attention.py index fe368f33e..a5ed6c13e 100644 --- a/server/clip_server/model/flash_attention.py +++ b/server/clip_server/model/flash_attention.py @@ -57,9 +57,12 @@ def attention( key_padding_mask: a bool tensor of shape (B, S) """ - assert not need_weights - assert q.dtype in [torch.float16, torch.bfloat16] - assert q.is_cuda + assert not need_weights, "not allowed to return weights." + assert q.dtype in [ + torch.float16, + torch.bfloat16, + ], f"flash attention only support torch.float16 or torch.bfloat16 but got {q.dtype}." + assert q.is_cuda, "flash attention only support cuda." if cu_seqlens is None: max_s = seqlen From cce60df1fa76767df6af62e16b46103a247f716d Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Thu, 17 Nov 2022 17:16:48 +0800 Subject: [PATCH 10/33] fix: housekeeping --- server/clip_server/model/model.py | 2 +- server/clip_server/model/modified_resnet.py | 202 +----------------- .../model/pretrained_text_encode.py | 198 +---------------- server/clip_server/model/utils.py | 64 +----- 4 files changed, 5 insertions(+), 461 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 78c736147..3424995d8 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -22,7 +22,7 @@ from open_clip.timm_model import TimmModel from open_clip.factory import _MODEL_CONFIGS -from .modified_resnet import ModifiedResNet +from .modified_resnet import CasModifiedResNet as ModifiedResNet from .transformer import ( LayerNormFp32, LayerNorm, diff --git a/server/clip_server/model/modified_resnet.py b/server/clip_server/model/modified_resnet.py index acba87c83..057ceb9d5 100644 --- a/server/clip_server/model/modified_resnet.py +++ b/server/clip_server/model/modified_resnet.py @@ -1,204 +1,7 @@ -from collections import OrderedDict +from open_clip.modified_resnet import ModifiedResNet -import torch -from torch import nn -from torch.nn import functional as F - -from open_clip.utils import freeze_batch_norm_2d - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.act1 = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.act2 = nn.ReLU(inplace=True) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.act3 = nn.ReLU(inplace=True) - - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential( - OrderedDict( - [ - ("-1", nn.AvgPool2d(stride)), - ( - "0", - nn.Conv2d( - inplanes, - planes * self.expansion, - 1, - stride=1, - bias=False, - ), - ), - ("1", nn.BatchNorm2d(planes * self.expansion)), - ] - ) - ) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.act1(self.bn1(self.conv1(x))) - out = self.act2(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.act3(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__( - self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None - ): - super().__init__() - self.positional_embedding = nn.Parameter( - torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 - ) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( - 2, 0, 1 - ) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat( - [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] - ), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, image_size=224, width=64): - super().__init__() - self.output_dim = output_dim - self.image_size = image_size - - # the 3-layer stem - self.conv1 = nn.Conv2d( - 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False - ) - self.bn1 = nn.BatchNorm2d(width // 2) - self.act1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d( - width // 2, width // 2, kernel_size=3, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm2d(width // 2) - self.act2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.act3 = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(2) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) - - self.init_parameters() - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def init_parameters(self): - if self.attnpool is not None: - std = self.attnpool.c_proj.in_features**-0.5 - nn.init.normal_(self.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert ( - unlocked_groups == 0 - ), 'partial locking not currently supported for this model' - for param in self.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - # FIXME support for non-transformer - pass - - def stem(self, x): - x = self.act1(self.bn1(self.conv1(x))) - x = self.act2(self.bn2(self.conv2(x))) - x = self.act3(self.bn3(self.conv3(x))) - x = self.avgpool(x) - return x +class CasModifiedResNet(ModifiedResNet): def forward(self, x): # To handle fp16 inference x = x.type(self.conv1.weight.dtype) @@ -208,5 +11,4 @@ def forward(self, x): x = self.layer3(x) x = self.layer4(x) x = self.attnpool(x) - return x diff --git a/server/clip_server/model/pretrained_text_encode.py b/server/clip_server/model/pretrained_text_encode.py index 74e236397..2392b889a 100644 --- a/server/clip_server/model/pretrained_text_encode.py +++ b/server/clip_server/model/pretrained_text_encode.py @@ -1,197 +1 @@ -import re - -import torch -import torch.nn as nn -from torch import TensorType - -try: - import transformers - from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig - from transformers.modeling_outputs import ( - 
BaseModelOutput, - BaseModelOutputWithPooling, - BaseModelOutputWithPoolingAndCrossAttentions, - ) -except ImportError as e: - transformers = None - - class BaseModelOutput: - pass - - class PretrainedConfig: - pass - - -# HF architecture dict: -arch_dict = { - "roberta": { - "config_names": { - "context_length": "max_position_embeddings", - "vocab_size": "vocab_size", - "width": "hidden_size", - "heads": "num_attention_heads", - "layers": "num_hidden_layers", - }, - "pooler": "mean_pooler", - }, - # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig - "xlm-roberta": { - "config_names": { - "context_length": "max_position_embeddings", - "vocab_size": "vocab_size", - "width": "hidden_size", - "heads": "num_attention_heads", - "layers": "num_hidden_layers", - }, - "pooler": "mean_pooler", - }, -} - - -# utils -def _camel2snake(s): - return re.sub(r'(? TensorType: - attn_mask = (x != self.config.pad_token_id).long() - out = self.transformer(input_ids=x, attention_mask=attn_mask) - pooled_out = self.pooler(out, attn_mask) - - return self.proj(pooled_out) - - def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): - if not unlocked_layers: # full freezing - for n, p in self.transformer.named_parameters(): - p.requires_grad = ( - (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False - ) - return - - n_layers = ( - len(self.transformer.encoder.layer) - unlocked_layers - 1 - ) # -1 for embeddings - modules = [ - self.transformer.embeddings, - self.transformer.encoder.layer[:n_layers], - ] - for module in modules: - for n, p in module.named_parameters(): - p.requires_grad = ( - (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False - ) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.gradient_checkpointing_enable() - - def init_parameters(self): - pass +from open_clip.hf_model import PreTrainedTextEncoder diff --git a/server/clip_server/model/utils.py b/server/clip_server/model/utils.py index 6afdc4b90..c53928eeb 100644 --- a/server/clip_server/model/utils.py +++ b/server/clip_server/model/utils.py @@ -1,63 +1 @@ -from itertools import repeat -import collections.abc - -from torch import nn as nn -from torchvision.ops.misc import FrozenBatchNorm2d - - -def freeze_batch_norm_2d(module, module_match={}, name=''): - """ - Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is - itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and - returned. Otherwise, the module is walked recursively and submodules are converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. 
- module_match (dict): Dictionary of full module names to freeze (all if empty) - name (str): Full module name (prefix) - - Returns: - torch.nn.Module: Resulting module - - Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 - """ - res = module - is_match = True - if module_match: - is_match = name in module_match - if is_match and isinstance( - module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) - ): - res = FrozenBatchNorm2d(module.num_features) - res.num_features = module.num_features - res.affine = module.affine - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for child_name, child in module.named_children(): - full_child_name = '.'.join([name, child_name]) if name else child_name - new_child = freeze_batch_norm_2d(child, module_match, full_child_name) - if new_child is not child: - res.add_module(child_name, new_child) - return res - - -# From PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) -to_3tuple = _ntuple(3) -to_4tuple = _ntuple(4) -to_ntuple = lambda n, x: _ntuple(n)(x) +from open_clip.utils import to_2tuple From a93d7387705f22d90c467e56d1fb76e4d5d44f70 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Thu, 17 Nov 2022 17:36:24 +0800 Subject: [PATCH 11/33] fix: cleanup --- server/clip_server/model/model.py | 35 +++-- server/clip_server/model/modified_resnet.py | 14 -- .../model/pretrained_text_encode.py | 1 - server/clip_server/model/transformer.py | 126 +----------------- server/clip_server/model/utils.py | 1 - 5 files changed, 23 insertions(+), 154 deletions(-) delete mode 100644 server/clip_server/model/modified_resnet.py delete mode 100644 server/clip_server/model/pretrained_text_encode.py delete mode 100644 server/clip_server/model/utils.py diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 201c2dfa5..d59641cb3 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -11,28 +11,35 @@ import logging import math import warnings +import numpy as np from dataclasses import dataclass from typing import Tuple, Union, Optional from copy import deepcopy -import numpy as np import torch -import torch.nn.functional as F from torch import nn +import torch.nn.functional as F +from clip_server.model.transformer import TextTransformer, VisionTransformer +from open_clip.transformer import QuickGELU, LayerNorm, LayerNormFp32, Attention from open_clip.timm_model import TimmModel from open_clip.factory import _MODEL_CONFIGS -from .modified_resnet import CasModifiedResNet as ModifiedResNet -from .transformer import ( - LayerNormFp32, - LayerNorm, - QuickGELU, - TextTransformer, - VisionTransformer, - Attention, -) -from .utils import to_2tuple -from .pretrained_text_encode import PreTrainedTextEncoder +from open_clip.hf_model import PreTrainedTextEncoder +from open_clip.modified_resnet import ModifiedResNet +from open_clip.utils import to_2tuple + + +class CasModifiedResNet(ModifiedResNet): + def forward(self, x): + # To handle fp16 inference + x = x.type(self.conv1.weight.dtype) + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) 
+ x = self.layer4(x) + x = self.attnpool(x) + return x @dataclass @@ -113,7 +120,7 @@ def _build_vision_tower( ) # so that text transformer doesn't use QuickGELU w/ timm models elif isinstance(vision_cfg.layers, (tuple, list)): vision_heads = vision_cfg.width * 32 // vision_cfg.head_width - visual = ModifiedResNet( + visual = CasModifiedResNet( layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, diff --git a/server/clip_server/model/modified_resnet.py b/server/clip_server/model/modified_resnet.py deleted file mode 100644 index 057ceb9d5..000000000 --- a/server/clip_server/model/modified_resnet.py +++ /dev/null @@ -1,14 +0,0 @@ -from open_clip.modified_resnet import ModifiedResNet - - -class CasModifiedResNet(ModifiedResNet): - def forward(self, x): - # To handle fp16 inference - x = x.type(self.conv1.weight.dtype) - x = self.stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - return x diff --git a/server/clip_server/model/pretrained_text_encode.py b/server/clip_server/model/pretrained_text_encode.py deleted file mode 100644 index 2392b889a..000000000 --- a/server/clip_server/model/pretrained_text_encode.py +++ /dev/null @@ -1 +0,0 @@ -from open_clip.hf_model import PreTrainedTextEncoder diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py index b7d5a09d4..49b0f1c31 100644 --- a/server/clip_server/model/transformer.py +++ b/server/clip_server/model/transformer.py @@ -1,5 +1,4 @@ from collections import OrderedDict -import math from typing import Callable, Optional import torch @@ -7,7 +6,8 @@ from torch.nn import functional as F from torch.utils.checkpoint import checkpoint -from .utils import to_2tuple +from open_clip.utils import to_2tuple +from open_clip.transformer import LayerNorm # Use flash attention try: @@ -18,128 +18,6 @@ FLASH_ATTENTION_AVAILABLE = False -class LayerNormFp32(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - x = F.layer_norm( - x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps - ) - return x.to(orig_type) - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - x = F.layer_norm( - x.type(torch.float32), - self.normalized_shape, - self.weight, - self.bias, - self.eps, - ) - return x.to(orig_type) - - -class QuickGELU(nn.Module): - # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - -class Attention(nn.Module): - def __init__( - self, - dim, - num_heads=8, - qkv_bias=True, - scaled_cosine=False, - scale_heads=False, - logit_scale_max=math.log(1.0 / 0.01), - attn_drop=0.0, - proj_drop=0.0, - ): - super().__init__() - self.scaled_cosine = scaled_cosine - self.scale_heads = scale_heads - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim**-0.5 - self.logit_scale_max = logit_scale_max - - # keeping in_proj in this form (instead of 
nn.Linear) to match weight scheme of original - self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) - if qkv_bias: - self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) - else: - self.in_proj_bias = None - - if self.scaled_cosine: - self.logit_scale = nn.Parameter( - torch.log(10 * torch.ones((num_heads, 1, 1))) - ) - else: - self.logit_scale = None - self.attn_drop = nn.Dropout(attn_drop) - if self.scale_heads: - self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) - else: - self.head_scale = None - self.out_proj = nn.Linear(dim, dim) - self.out_drop = nn.Dropout(proj_drop) - - def forward(self, x, attn_mask: Optional[torch.Tensor] = None): - L, N, C = x.shape - q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - - if self.logit_scale is not None: - attn = torch.bmm( - F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2) - ) - logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() - attn = attn.view(N, self.num_heads, L, L) * logit_scale - attn = attn.view(-1, L, L) - else: - q = q * self.scale - attn = torch.bmm(q, k.transpose(-1, -2)) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, float("-inf")) - attn_mask = new_attn_mask - attn += attn_mask - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = torch.bmm(attn, v) - if self.head_scale is not None: - x = x.view(N, self.num_heads, L, C) * self.head_scale - x = x.view(-1, L, C) - x = x.transpose(0, 1).reshape(L, N, C) - x = self.out_proj(x) - x = self.out_drop(x) - return x - - class ResidualAttentionBlock(nn.Module): def __init__( self, diff --git a/server/clip_server/model/utils.py b/server/clip_server/model/utils.py deleted file mode 100644 index c53928eeb..000000000 --- a/server/clip_server/model/utils.py +++ /dev/null @@ -1 +0,0 @@ -from open_clip.utils import to_2tuple From cc3514d87dcb2ec28ae818e7e685e72da788b20c Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Thu, 17 Nov 2022 17:53:53 +0800 Subject: [PATCH 12/33] fix: cleanup --- server/clip_server/model/model.py | 63 ------------------------------- 1 file changed, 63 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index d59641cb3..78949c7de 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -601,66 +601,3 @@ def load_openclip_model( model = torch.jit.script(model) return model - - -def trace_model(model, batch_size=256, device=torch.device('cpu')): - model.eval() - image_size = model.visual.image_size - example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) - example_text = torch.zeros( - (batch_size, model.context_length), dtype=torch.int, device=device - ) - model = torch.jit.trace_module( - model, - inputs=dict( - forward=(example_images, example_text), - encode_text=(example_text,), - encode_image=(example_images,), - ), - ) - model.visual.image_size = image_size - return model - - -def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1): - # Rescale the grid of position embeddings when loading from state_dict - old_pos_embed = state_dict.get('visual.positional_embedding', None) - if old_pos_embed is 
None or not hasattr(model.visual, 'grid_size'): - return - grid_size = to_2tuple(model.visual.grid_size) - extra_tokens = ( - 1 # FIXME detect different token configs (ie no class token, or more) - ) - new_seq_len = grid_size[0] * grid_size[1] + extra_tokens - if new_seq_len == old_pos_embed.shape[0]: - return - - if extra_tokens: - pos_emb_tok, pos_emb_img = ( - old_pos_embed[:extra_tokens], - old_pos_embed[extra_tokens:], - ) - else: - pos_emb_tok, pos_emb_img = None, old_pos_embed - old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) - - logging.info( - 'Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size - ) - pos_emb_img = pos_emb_img.reshape( - 1, old_grid_size[0], old_grid_size[1], -1 - ).permute(0, 3, 1, 2) - pos_emb_img = F.interpolate( - pos_emb_img, - size=grid_size, - mode=interpolation, - align_corners=True, - ) - pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape( - 1, grid_size[0] * grid_size[1], -1 - )[0] - if pos_emb_tok is not None: - new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) - else: - new_pos_embed = pos_emb_img - state_dict['visual.positional_embedding'] = new_pos_embed From 1e94a6e96ae58511f5461a8c1ce2100f1875ffb9 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Fri, 18 Nov 2022 11:29:08 +0800 Subject: [PATCH 13/33] fix: allow to set precision --- server/clip_server/model/model.py | 13 ++++++------- server/clip_server/model/openclip_model.py | 19 ++++++++++++++++--- server/clip_server/model/transformer.py | 8 ++++++++ 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 78949c7de..6fca93ba3 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -8,8 +8,6 @@ Ludwig Schmidt """ -import logging -import math import warnings import numpy as np from dataclasses import dataclass @@ -26,7 +24,6 @@ from open_clip.factory import _MODEL_CONFIGS from open_clip.hf_model import PreTrainedTextEncoder from open_clip.modified_resnet import ModifiedResNet -from open_clip.utils import to_2tuple class CasModifiedResNet(ModifiedResNet): @@ -420,6 +417,7 @@ def build_model_from_openai_state_dict( def load_openai_model( model_path: str, device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu', + precision: str = None, jit=True, ): """Load a CLIP model @@ -434,8 +432,6 @@ def load_openai_model( The device to put the loaded model jit : bool Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
- cache_dir : Optional[str] - The directory to cache the downloaded model weights Returns ------- model : torch.nn.Module @@ -443,7 +439,8 @@ def load_openai_model( preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ - precision = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' + if precision is None: + precision = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() @@ -549,11 +546,13 @@ def load_openclip_model( model_path: str, device: Union[str, torch.device] = 'cpu', jit: bool = False, + precision: str = None, force_quick_gelu: bool = False, force_custom_text: bool = False, pretrained_image: bool = False, ): - precision = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' + if precision is None: + precision = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' model_name = model_name.replace( '/', '-' ) # for callers using old naming with / in ViT names diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py index 9a329f9c5..f757a7b13 100644 --- a/server/clip_server/model/openclip_model.py +++ b/server/clip_server/model/openclip_model.py @@ -13,7 +13,14 @@ class OpenCLIPModel(CLIPModel): - def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): + def __init__( + self, + name: str, + device: str = 'cpu', + jit: bool = False, + precision: str = None, + **kwargs + ): super().__init__(name, **kwargs) if '::' in name: @@ -28,10 +35,16 @@ def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): model_path = download_model(model_url, md5sum=md5sum) if pretrained == 'openai': - self._model = load_openai_model(model_path, device=device, jit=jit) + self._model = load_openai_model( + model_path, device=device, jit=jit, precision=precision + ) else: self._model = load_openclip_model( - self._model_name, model_path=model_path, device=device, jit=jit + self._model_name, + model_path=model_path, + device=device, + jit=jit, + precision=precision, ) @staticmethod diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py index 49b0f1c31..e817afa64 100644 --- a/server/clip_server/model/transformer.py +++ b/server/clip_server/model/transformer.py @@ -1,3 +1,11 @@ +# Originally from https://github.com/mlfoundations/open_clip. 
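With the precision argument threaded through openclip_model.py above, callers can override the default of fp32 on CPU and fp16 on CUDA. A minimal usage sketch (illustrative only; 'RN50::openai' is one of the model tags used in the test suite, and later commits in this series rename the argument to dtype):

    from clip_server.model.openclip_model import OpenCLIPModel

    # Run fp32 inference on GPU instead of the fp16 default.
    model = OpenCLIPModel('RN50::openai', device='cuda', jit=False, precision='fp32')

    # Omitting precision keeps the defaults: fp32 on CPU, fp16 on CUDA.
    cpu_model = OpenCLIPModel('RN50::openai', device='cpu')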
+# +# Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, +# Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, +# John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, +# Ludwig Schmidt + + from collections import OrderedDict from typing import Callable, Optional From 0d2f871911b4ee751c1aa54d0d8f7e2d0567f720 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Fri, 18 Nov 2022 11:33:15 +0800 Subject: [PATCH 14/33] fix: gpu-test --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 949e4c2d0..e331c31be 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -113,6 +113,7 @@ jobs: pip install --no-cache-dir "server/[onnx]" pip install --no-cache-dir "server/[transformers]" pip install --no-cache-dir "server/[search]" + pip install --no-cache-dir "server/[transformers]" - name: Test id: test run: | @@ -163,7 +164,6 @@ jobs: } pip install -e "client/[test]" pip install -e "server/[tensorrt]" - pip install -e "server/[transformers]" - name: Test id: test run: | From 5b0a65a059417e5e80e637349893297855533aba Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Fri, 18 Nov 2022 11:48:46 +0800 Subject: [PATCH 15/33] fix: add roberta model test --- tests/test_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model.py b/tests/test_model.py index 29333b974..053a1d2b0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -10,6 +10,7 @@ [ ('ViT-L/14@336px', OpenCLIPModel), ('RN50::openai', OpenCLIPModel), + ('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel), ('M-CLIP/LABSE-Vit-L-14', MultilingualCLIPModel), ], ) From 43b7a3187a3f081310657b7af85dc75f310c5625 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Fri, 18 Nov 2022 12:36:02 +0800 Subject: [PATCH 16/33] fix: dtype --- server/clip_server/model/model.py | 52 ++++++++++++++----------------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 6fca93ba3..59826a5a6 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -23,20 +23,14 @@ from open_clip.timm_model import TimmModel from open_clip.factory import _MODEL_CONFIGS from open_clip.hf_model import PreTrainedTextEncoder -from open_clip.modified_resnet import ModifiedResNet +from open_clip.modified_resnet import ModifiedResNet as _ModifiedResNet -class CasModifiedResNet(ModifiedResNet): +class ModifiedResNet(_ModifiedResNet): def forward(self, x): # To handle fp16 inference x = x.type(self.conv1.weight.dtype) - x = self.stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - return x + return super().forward(x) @dataclass @@ -77,13 +71,13 @@ class CLIPTextCfg: pooler_type: str = 'mean_pooler' -def get_cast_dtype(precision: str): +def get_cast_dtype(dtype: str): cast_dtype = None - if precision == 'bf16': + if dtype == 'bf16': cast_dtype = torch.bfloat16 - elif precision == 'fp16': + elif dtype == 'fp16': cast_dtype = torch.float16 - elif precision == 'fp32': + elif dtype == 'fp32': cast_dtype = torch.float32 return cast_dtype @@ -117,7 +111,7 @@ def _build_vision_tower( ) # so that text transformer doesn't use QuickGELU w/ timm models elif isinstance(vision_cfg.layers, (tuple, list)): vision_heads = vision_cfg.width * 32 // vision_cfg.head_width - visual = CasModifiedResNet( + visual = ModifiedResNet( layers=vision_cfg.layers, output_dim=embed_dim, 
heads=vision_heads, @@ -417,7 +411,7 @@ def build_model_from_openai_state_dict( def load_openai_model( model_path: str, device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu', - precision: str = None, + dtype: str = None, jit=True, ): """Load a CLIP model @@ -426,7 +420,7 @@ def load_openai_model( ---------- model_path : str The path to a model checkpoint containing the state_dict - precision: str + dtype: str Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. device : Union[str, torch.device] The device to put the loaded model @@ -439,8 +433,8 @@ def load_openai_model( preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ - if precision is None: - precision = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' + if dtype is None: + dtype = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() @@ -456,7 +450,7 @@ def load_openai_model( if not jit: # Build a non-jit model from the OpenAI jitted model state dict - cast_dtype = get_cast_dtype(precision) + cast_dtype = get_cast_dtype(dtype) try: model = build_model_from_openai_state_dict( state_dict or model.state_dict(), cast_dtype=cast_dtype @@ -467,9 +461,9 @@ def load_openai_model( # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use model = model.to(device) - if precision.startswith('amp') or precision == 'fp32': + if dtype.startswith('amp') or dtype == 'fp32': model.float() - elif precision == 'bf16': + elif dtype == 'bf16': convert_weights_to_lp(model, dtype=torch.bfloat16) return model @@ -505,7 +499,7 @@ def patch_device(module): patch_device(model.encode_text) # patch dtype to float32 (typically for CPU) - if precision == 'fp32': + if dtype == 'fp32': float_holder = torch.jit.trace( lambda: torch.ones([]).float(), example_inputs=[] ) @@ -544,15 +538,15 @@ def patch_float(module): def load_openclip_model( model_name: str, model_path: str, - device: Union[str, torch.device] = 'cpu', + device: Optional(Union[str, torch.device]) = 'cpu', jit: bool = False, - precision: str = None, + dtype: str = None, force_quick_gelu: bool = False, force_custom_text: bool = False, pretrained_image: bool = False, ): - if precision is None: - precision = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' + if dtype is None: + dtype = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' model_name = model_name.replace( '/', '-' ) # for callers using old naming with / in ViT names @@ -575,7 +569,7 @@ def load_openclip_model( False ), 'pretrained image towers currently only supported for timm models' - cast_dtype = get_cast_dtype(precision) + cast_dtype = get_cast_dtype(dtype) custom_text = ( model_cfg.pop('custom_text', False) or force_custom_text @@ -591,9 +585,9 @@ def load_openclip_model( model.load_state_dict(load_state_dict(model_path)) model.to(device=device) - if precision in ("fp16", "bf16"): + if dtype in ("fp16", "bf16"): convert_weights_to_lp( - model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16 + model, dtype=torch.bfloat16 if dtype == 'bf16' else torch.float16 ) if jit: From 9efb369d91dce42d483c020547a531c2b5f5e16a Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Fri, 18 Nov 2022 12:40:49 +0800 Subject: [PATCH 17/33] fix: dtype --- server/clip_server/model/openclip_model.py | 6 
+++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py index f757a7b13..5d15be9f2 100644 --- a/server/clip_server/model/openclip_model.py +++ b/server/clip_server/model/openclip_model.py @@ -18,7 +18,7 @@ def __init__( name: str, device: str = 'cpu', jit: bool = False, - precision: str = None, + dtype: str = None, **kwargs ): super().__init__(name, **kwargs) @@ -36,7 +36,7 @@ def __init__( if pretrained == 'openai': self._model = load_openai_model( - model_path, device=device, jit=jit, precision=precision + model_path, device=device, jit=jit, dtype=dtype ) else: self._model = load_openclip_model( @@ -44,7 +44,7 @@ def __init__( model_path=model_path, device=device, jit=jit, - precision=precision, + dtype=dtype, ) @staticmethod From 848f0aa4118b2a96fc24e1df9ac104c74aa36eb6 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Fri, 18 Nov 2022 13:13:14 +0800 Subject: [PATCH 18/33] fix: remove optional --- server/clip_server/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 59826a5a6..89a24e15e 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -538,7 +538,7 @@ def patch_float(module): def load_openclip_model( model_name: str, model_path: str, - device: Optional(Union[str, torch.device]) = 'cpu', + device: Union[str, torch.device] = 'cpu', jit: bool = False, dtype: str = None, force_quick_gelu: bool = False, From b2472281bdf8a3d49755687cc24a5ebd88c55a13 Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Fri, 18 Nov 2022 18:40:07 +0800 Subject: [PATCH 19/33] fix: housekeeping --- server/clip_server/model/model.py | 168 +++++++----------------- server/clip_server/model/transformer.py | 16 +-- tests/test_model.py | 2 +- 3 files changed, 55 insertions(+), 131 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 89a24e15e..1028e48dd 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -24,7 +24,8 @@ from open_clip.factory import _MODEL_CONFIGS from open_clip.hf_model import PreTrainedTextEncoder from open_clip.modified_resnet import ModifiedResNet as _ModifiedResNet - +from open_clip.model import CustomTextCLIP as _CustomTextCLIP +from open_clip.model import CLIP as _CLIP class ModifiedResNet(_ModifiedResNet): def forward(self, x): @@ -71,22 +72,11 @@ class CLIPTextCfg: pooler_type: str = 'mean_pooler' -def get_cast_dtype(dtype: str): - cast_dtype = None - if dtype == 'bf16': - cast_dtype = torch.bfloat16 - elif dtype == 'fp16': - cast_dtype = torch.float16 - elif dtype == 'fp32': - cast_dtype = torch.float32 - return cast_dtype - - def _build_vision_tower( embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, + dtype: Optional[torch.dtype] = torch.float32, ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) @@ -122,7 +112,7 @@ def _build_vision_tower( vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = ( LayerNormFp32 - if cast_dtype in (torch.float16, torch.bfloat16) + if dtype in (torch.float16, torch.bfloat16) else LayerNorm ) visual = VisionTransformer( @@ -135,7 +125,7 @@ def _build_vision_tower( output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, - cast_dtype=cast_dtype, + dtype=dtype, ) return visual @@ -145,7 +135,7 @@ def _build_text_tower( embed_dim: int, 
text_cfg: CLIPTextCfg, quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, + dtype: Optional[torch.dtype] = torch.float32, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) @@ -161,7 +151,7 @@ def _build_text_tower( act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( LayerNormFp32 - if cast_dtype in (torch.float16, torch.bfloat16) + if dtype in (torch.float16, torch.bfloat16) else LayerNorm ) @@ -174,113 +164,48 @@ def _build_text_tower( output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, - cast_dtype=cast_dtype, + dtype=dtype, ) return text -class CLIP(nn.Module): +class CustomTextCLIP(_CustomTextCLIP): def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + dtype: Optional[torch.dtype] = torch.float32, ): - super().__init__() - self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype) + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, dtype) + - text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) +class CLIP(_CLIP): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + dtype: Optional[torch.dtype] = torch.float32, + ): + super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype) + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, dtype) + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, dtype) self.transformer = text.transformer self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.register_buffer('attn_mask', text.attn_mask, persistent=False) - def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): - # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 - self.visual.lock( - unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats - ) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.visual.set_grad_checkpointing(enable) - self.transformer.grad_checkpointing = enable - - def encode_image(self, image, normalize: bool = False): - features = self.visual(image) - return F.normalize(features, dim=-1) if normalize else features - - def encode_text(self, text, normalize: bool = False): - cast_dtype = self.transformer.get_cast_dtype() - - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.to(cast_dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - return F.normalize(x, dim=-1) if normalize else x - - def forward(self, image, text): - image_features = self.encode_image(image, normalize=True) - text_features = self.encode_text(text, normalize=True) - 
return image_features, text_features, self.logit_scale.exp() - - -class CustomTextCLIP(nn.Module): - def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, - ): - super().__init__() - self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) - self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): - # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 - self.visual.lock( - unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats - ) - - def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): - self.text.lock(unlocked_layers, freeze_layer_norm) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.visual.set_grad_checkpointing(enable) - self.text.set_grad_checkpointing(enable) - - def encode_image(self, image, normalize: bool = False): - features = self.visual(image) - return F.normalize(features, dim=-1) if normalize else features - - def encode_text(self, text, normalize: bool = False): - features = self.text(text) - return F.normalize(features, dim=-1) if normalize else features - - def forward(self, image, text): - image_features = self.encode_image(image, normalize=True) - text_features = self.encode_text(text, normalize=True) - return image_features, text_features, self.logit_scale.exp() - def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): """Convert applicable model parameters to low-precision (bf16 or fp16)""" - def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.to(dtype) @@ -324,7 +249,7 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, - cast_dtype=torch.float16, + dtype=torch.float32, ): vit = "visual.proj" in state_dict @@ -396,7 +321,7 @@ def build_model_from_openai_state_dict( vision_cfg=vision_cfg, text_cfg=text_cfg, quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU - cast_dtype=cast_dtype, + dtype=dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: @@ -411,7 +336,7 @@ def build_model_from_openai_state_dict( def load_openai_model( model_path: str, device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu', - dtype: str = None, + dtype: Optional[Union[str, torch.dtype]] = None, jit=True, ): """Load a CLIP model @@ -434,7 +359,7 @@ def load_openai_model( A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if dtype is None: - dtype = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' + dtype = torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16 try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() @@ -450,20 +375,19 @@ def load_openai_model( if not jit: # Build a non-jit model from the OpenAI jitted model state dict - cast_dtype = get_cast_dtype(dtype) try: model = build_model_from_openai_state_dict( - state_dict or model.state_dict(), cast_dtype=cast_dtype + state_dict or model.state_dict(), dtype=dtype ) except KeyError: sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} - model = build_model_from_openai_state_dict(sd, 
cast_dtype=cast_dtype) + model = build_model_from_openai_state_dict(sd, dtype=dtype) # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use model = model.to(device) - if dtype.startswith('amp') or dtype == 'fp32': + if dtype == torch.float32 or dtype.startswith('amp'): model.float() - elif dtype == 'bf16': + elif dtype == torch.bfloat16: convert_weights_to_lp(model, dtype=torch.bfloat16) return model @@ -540,13 +464,14 @@ def load_openclip_model( model_path: str, device: Union[str, torch.device] = 'cpu', jit: bool = False, - dtype: str = None, + dtype: Optional[Union[str, torch.dtype]] = None, force_quick_gelu: bool = False, force_custom_text: bool = False, pretrained_image: bool = False, ): if dtype is None: - dtype = 'fp32' if device in ('cpu', torch.device('cpu')) else 'fp16' + dtype = torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16 + model_name = model_name.replace( '/', '-' ) # for callers using old naming with / in ViT names @@ -569,7 +494,6 @@ def load_openclip_model( False ), 'pretrained image towers currently only supported for timm models' - cast_dtype = get_cast_dtype(dtype) custom_text = ( model_cfg.pop('custom_text', False) or force_custom_text @@ -577,17 +501,17 @@ def load_openclip_model( ) if custom_text: - model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + model = CustomTextCLIP(**model_cfg, dtype=dtype) else: - model = CLIP(**model_cfg, cast_dtype=cast_dtype) + model = CLIP(**model_cfg, dtype=dtype) model.eval() model.load_state_dict(load_state_dict(model_path)) model.to(device=device) - if dtype in ("fp16", "bf16"): + if dtype in (torch.float16, torch.bfloat16): convert_weights_to_lp( - model, dtype=torch.bfloat16 if dtype == 'bf16' else torch.float16 + model, dtype=dtype ) if jit: diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py index e817afa64..14ab4f114 100644 --- a/server/clip_server/model/transformer.py +++ b/server/clip_server/model/transformer.py @@ -36,7 +36,7 @@ def __init__( norm_layer: Callable = LayerNorm, scale_attn: bool = False, scale_fc: bool = False, - cast_dtype=torch.float32, + dtype='fp32', ): super().__init__() head_dim = d_model // n_head @@ -47,7 +47,7 @@ def __init__( MultiheadAttention(d_model, n_head) if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available() - and cast_dtype in (torch.float16, torch.bfloat16) + and dtype in ('fp16', 'bf16') and self.flash_attention else nn.MultiheadAttention(d_model, n_head) ) @@ -85,7 +85,7 @@ def __init__( mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - cast_dtype=torch.float32, + dtype='fp32', ): super().__init__() self.width = width @@ -100,7 +100,7 @@ def __init__( mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, - cast_dtype=cast_dtype, + dtype=dtype, ) for _ in range(layers) ] @@ -129,7 +129,7 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - cast_dtype=torch.float16, + dtype='fp32', ): super().__init__() self.context_length = context_length @@ -147,7 +147,7 @@ def __init__( heads=heads, act_layer=act_layer, norm_layer=norm_layer, - cast_dtype=cast_dtype, + dtype=dtype, ) self.ln_final = norm_layer(width) self.text_projection = nn.Parameter(torch.empty(width, output_dim)) @@ -216,7 +216,7 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - cast_dtype=torch.float32, + dtype='fp32', ): super().__init__() self.image_size 
= to_2tuple(image_size) @@ -247,7 +247,7 @@ def __init__( mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, - cast_dtype=cast_dtype, + dtype=dtype, ) self.ln_post = norm_layer(width) diff --git a/tests/test_model.py b/tests/test_model.py index 053a1d2b0..79ee3d7fb 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -10,7 +10,7 @@ [ ('ViT-L/14@336px', OpenCLIPModel), ('RN50::openai', OpenCLIPModel), - ('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel), + # ('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel), ('M-CLIP/LABSE-Vit-L-14', MultilingualCLIPModel), ], ) From b1fc4d6479826d480aa3390ff844a490ae2c0e1f Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Fri, 18 Nov 2022 18:43:03 +0800 Subject: [PATCH 20/33] fix: housekeeping --- server/clip_server/model/model.py | 46 +++++++++++++++---------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 1028e48dd..d097066f5 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -27,6 +27,7 @@ from open_clip.model import CustomTextCLIP as _CustomTextCLIP from open_clip.model import CLIP as _CLIP + class ModifiedResNet(_ModifiedResNet): def forward(self, x): # To handle fp16 inference @@ -111,9 +112,7 @@ def _build_vision_tower( else: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = ( - LayerNormFp32 - if dtype in (torch.float16, torch.bfloat16) - else LayerNorm + LayerNormFp32 if dtype in (torch.float16, torch.bfloat16) else LayerNorm ) visual = VisionTransformer( image_size=vision_cfg.image_size, @@ -150,9 +149,7 @@ def _build_text_tower( else: act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( - LayerNormFp32 - if dtype in (torch.float16, torch.bfloat16) - else LayerNorm + LayerNormFp32 if dtype in (torch.float16, torch.bfloat16) else LayerNorm ) text = TextTransformer( @@ -171,12 +168,12 @@ def _build_text_tower( class CustomTextCLIP(_CustomTextCLIP): def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - dtype: Optional[torch.dtype] = torch.float32, + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + dtype: Optional[torch.dtype] = torch.float32, ): super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype) self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, dtype) @@ -185,12 +182,12 @@ def __init__( class CLIP(_CLIP): def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - dtype: Optional[torch.dtype] = torch.float32, + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + dtype: Optional[torch.dtype] = torch.float32, ): super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype) self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, dtype) @@ -206,6 +203,7 @@ def __init__( def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): """Convert applicable model parameters to low-precision (bf16 or fp16)""" + def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.to(dtype) @@ -359,7 +357,9 @@ def load_openai_model( A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if dtype is None: - dtype = torch.float32 if device in ('cpu', torch.device('cpu')) else 
torch.float16 + dtype = ( + torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16 + ) try: # loading JIT archive model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() @@ -470,7 +470,9 @@ def load_openclip_model( pretrained_image: bool = False, ): if dtype is None: - dtype = torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16 + dtype = ( + torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16 + ) model_name = model_name.replace( '/', '-' @@ -510,9 +512,7 @@ def load_openclip_model( model.to(device=device) if dtype in (torch.float16, torch.bfloat16): - convert_weights_to_lp( - model, dtype=dtype - ) + convert_weights_to_lp(model, dtype=dtype) if jit: model = torch.jit.script(model) From 5478afc4912afb7000faf53b5d88726c4205db8c Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 00:17:41 +0800 Subject: [PATCH 21/33] fix: typo --- server/clip_server/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index d097066f5..f22542be0 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -247,7 +247,7 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, - dtype=torch.float32, + dtype=torch.float16, ): vit = "visual.proj" in state_dict @@ -423,7 +423,7 @@ def patch_device(module): patch_device(model.encode_text) # patch dtype to float32 (typically for CPU) - if dtype == 'fp32': + if dtype == torch.float32: float_holder = torch.jit.trace( lambda: torch.ones([]).float(), example_inputs=[] ) From ba3aa44faae504ebbb7b77e7f59f878e0beeabb1 Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 15:10:26 +0800 Subject: [PATCH 22/33] fix: refactor --- server/clip_server/model/transformer.py | 289 ++++++------------------ 1 file changed, 69 insertions(+), 220 deletions(-) diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py index 14ab4f114..da81f91aa 100644 --- a/server/clip_server/model/transformer.py +++ b/server/clip_server/model/transformer.py @@ -11,11 +11,11 @@ import torch from torch import nn -from torch.nn import functional as F -from torch.utils.checkpoint import checkpoint - -from open_clip.utils import to_2tuple from open_clip.transformer import LayerNorm +from open_clip.transformer import ResidualAttentionBlock as _ResidualAttentionBlock +from open_clip.transformer import Transformer as _Transformer +from open_clip.transformer import VisionTransformer as _VisionTransformer +from open_clip.transformer import TextTransformer as _TextTransformer # Use flash attention try: @@ -26,71 +26,50 @@ FLASH_ATTENTION_AVAILABLE = False -class ResidualAttentionBlock(nn.Module): +class ResidualAttentionBlock(_ResidualAttentionBlock): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, + ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - scale_attn: bool = False, - scale_fc: bool = False, - dtype='fp32', + # scale_attn: bool = False, + # scale_fc: bool = False, + dtype=torch.float32, ): - super().__init__() + super().__init__( + d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer + ) head_dim = d_model // n_head - self.flash_attention = head_dim % 8 == 0 and head_dim <= 128 + flash_attention = head_dim % 8 == 0 and head_dim <= 128 - self.ln_1 = norm_layer(d_model) 
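# At this step the refactored block picks its attention implementation: the
# flash-attention MultiheadAttention is used only when the import at the top of
# the file succeeded, a CUDA device is present, the requested dtype is fp16 or
# bf16, and the head dimension is a multiple of 8 no larger than 128; anything
# else falls back to nn.MultiheadAttention. A standalone sketch of that gate
# (use_flash_attention is an illustrative name, not a function in this patch):
import torch

def use_flash_attention(width: int, heads: int, dtype: torch.dtype,
                        flash_available: bool = False) -> bool:
    head_dim = width // heads
    return (
        flash_available                               # flash_attention module imported
        and torch.cuda.is_available()                 # flash path is CUDA-only
        and dtype in (torch.float16, torch.bfloat16)  # low-precision dtypes only
        and head_dim % 8 == 0 and head_dim <= 128     # head-dim constraint from the patch
    )

# e.g. a width-512, 8-head text block has head_dim 64 and qualifies on a fp16 GPU setup
print(use_flash_attention(512, 8, torch.float16, flash_available=True))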
self.attn = ( MultiheadAttention(d_model, n_head) if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available() - and dtype in ('fp16', 'bf16') - and self.flash_attention + and dtype in (torch.float16, torch.bfloat16) + and flash_attention else nn.MultiheadAttention(d_model, n_head) ) - self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity() - - self.ln_2 = norm_layer(d_model) - mlp_width = int(d_model * mlp_ratio) - self.mlp = nn.Sequential( - OrderedDict( - [ - ("c_fc", nn.Linear(d_model, mlp_width)), - ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)), - ] - ) - ) - - def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ln_attn(self.attention(self.ln_1(x), attn_mask=attn_mask)) - x = x + self.mlp(self.ln_2(x)) - return x -class Transformer(nn.Module): +class Transformer(_Transformer): def __init__( self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, + ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - dtype='fp32', + dtype=torch.float32, ): - super().__init__() - self.width = width - self.layers = layers - self.grad_checkpointing = False + super().__init__( + width, layers, heads, mlp_ratio, ls_init_value, act_layer, norm_layer + ) self.resblocks = nn.ModuleList( [ @@ -98,6 +77,7 @@ def __init__( width, heads, mlp_ratio, + ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, dtype=dtype, @@ -106,210 +86,79 @@ def __init__( ] ) - def get_cast_dtype(self) -> torch.dtype: - return self.resblocks[0].mlp.c_fc.weight.dtype - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - for r in self.resblocks: - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - return x - -class TextTransformer(nn.Module): +class VisionTransformer(_VisionTransformer): def __init__( self, - context_length: int = 77, - vocab_size: int = 49408, - width: int = 512, - heads: int = 8, - layers: int = 12, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - dtype='fp32', + dtype=torch.float32, ): - super().__init__() - self.context_length = context_length - self.vocab_size = vocab_size - self.width = width - self.output_dim = output_dim - - self.token_embedding = nn.Embedding(vocab_size, width) - self.positional_embedding = nn.Parameter( - torch.empty(self.context_length, width) + super().__init__( + image_size, + patch_size, + width, + layers, + heads, + mlp_ratio, + ls_init_value, + output_dim, + act_layer, + norm_layer, ) self.transformer = Transformer( - width=width, - layers=layers, - heads=heads, + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, dtype=dtype, ) - self.ln_final = norm_layer(width) - self.text_projection = nn.Parameter(torch.empty(width, output_dim)) - - self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) - - self.init_parameters() - - def init_parameters(self): - 
nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - proj_std = (self.transformer.width**-0.5) * ( - (2 * self.transformer.layers) ** -0.5 - ) - attn_std = self.transformer.width**-0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.grad_checkpointing = enable - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - def forward(self, text): - cast_dtype = self.transformer.get_cast_dtype() - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.to(cast_dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - -class VisionTransformer(nn.Module): +class TextTransformer(_TextTransformer): def __init__( self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - dtype='fp32', + dtype=torch.float32, ): - super().__init__() - self.image_size = to_2tuple(image_size) - self.patch_size = to_2tuple(patch_size) - self.grid_size = ( - self.image_size[0] // self.patch_size[0], - self.image_size[1] // self.patch_size[1], - ) - self.output_dim = output_dim - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False, + super().__init__( + context_length, + vocab_size, + width, + heads, + layers, + ls_init_value, + output_dim, + act_layer, + norm_layer, ) - scale = width**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter( - scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width) - ) - self.ln_pre = norm_layer(width) self.transformer = Transformer( - width, - layers, - heads, - mlp_ratio, + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, dtype=dtype, ) - - self.ln_post = norm_layer(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - self.init_parameters() - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert ( - unlocked_groups == 0 - ), 'partial locking not currently supported for this model' - for param in self.parameters(): - param.requires_grad = False - - def init_parameters(self): 
- # FIXME OpenAI CLIP did not define an init for the VisualTransformer - # TODO experiment if default PyTorch init, below, or alternate init is best. - - # nn.init.normal_(self.class_embedding, std=self.scale) - # nn.init.normal_(self.positional_embedding, std=self.scale) - # - # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - # attn_std = self.transformer.width ** -0.5 - # fc_std = (2 * self.transformer.width) ** -0.5 - # for block in self.transformer.resblocks: - # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - # - # if self.text_projection is not None: - # nn.init.normal_(self.text_projection, std=self.scale) - pass - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.grad_checkpointing = enable - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [ - self.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - - return x From 2b7af82e77aefaac8e8c258f154d91c70f01be36 Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 15:37:45 +0800 Subject: [PATCH 23/33] fix: refactor --- server/clip_server/model/model.py | 157 +++++++++++- server/clip_server/model/transformer.py | 328 ++++++++++++------------ 2 files changed, 318 insertions(+), 167 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index f22542be0..2b757d1ed 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -9,24 +9,37 @@ """ import warnings -import numpy as np +from typing import Callable from dataclasses import dataclass from typing import Tuple, Union, Optional from copy import deepcopy import torch from torch import nn -import torch.nn.functional as F -from clip_server.model.transformer import TextTransformer, VisionTransformer +# import torch.nn.functional as F + +# from clip_server.model.transformer import TextTransformer, VisionTransformer from open_clip.transformer import QuickGELU, LayerNorm, LayerNormFp32, Attention from open_clip.timm_model import TimmModel from open_clip.factory import _MODEL_CONFIGS from open_clip.hf_model import PreTrainedTextEncoder +from open_clip.transformer import ResidualAttentionBlock as _ResidualAttentionBlock +from open_clip.transformer import Transformer as _Transformer +from open_clip.transformer import VisionTransformer as _VisionTransformer +from open_clip.transformer import TextTransformer as _TextTransformer from open_clip.modified_resnet import ModifiedResNet as _ModifiedResNet from open_clip.model import CustomTextCLIP as _CustomTextCLIP from open_clip.model import CLIP as _CLIP +# Use flash attention +try: + from clip_server.model.flash_attention import MultiheadAttention + + FLASH_ATTENTION_AVAILABLE = True +except: + FLASH_ATTENTION_AVAILABLE = False + class 
ModifiedResNet(_ModifiedResNet): def forward(self, x): @@ -35,6 +48,144 @@ def forward(self, x): return super().forward(x) +class ResidualAttentionBlock(_ResidualAttentionBlock): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + # scale_attn: bool = False, + # scale_fc: bool = False, + dtype=torch.float32, + ): + super().__init__( + d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer + ) + head_dim = d_model // n_head + flash_attention = head_dim % 8 == 0 and head_dim <= 128 + + self.attn = ( + MultiheadAttention(d_model, n_head) + if FLASH_ATTENTION_AVAILABLE + and torch.cuda.is_available() + and dtype in (torch.float16, torch.bfloat16) + and flash_attention + else nn.MultiheadAttention(d_model, n_head) + ) + + +class Transformer(_Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + dtype=torch.float32, + ): + super().__init__( + width, layers, heads, mlp_ratio, ls_init_value, act_layer, norm_layer + ) + + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + dtype=dtype, + ) + for _ in range(layers) + ] + ) + + +class VisionTransformer(_VisionTransformer): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + dtype=torch.float32, + ): + super().__init__( + image_size, + patch_size, + width, + layers, + heads, + mlp_ratio, + ls_init_value, + output_dim, + act_layer, + norm_layer, + ) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + dtype=dtype, + ) + + +class TextTransformer(_TextTransformer): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + dtype=torch.float32, + ): + super().__init__( + context_length, + vocab_size, + width, + heads, + layers, + ls_init_value, + output_dim, + act_layer, + norm_layer, + ) + + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + dtype=dtype, + ) + self.init_parameters() + + @dataclass class CLIPVisionCfg: layers: Union[Tuple[int, int, int, int], int] = 12 diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py index da81f91aa..1ae166484 100644 --- a/server/clip_server/model/transformer.py +++ b/server/clip_server/model/transformer.py @@ -1,164 +1,164 @@ -# Originally from https://github.com/mlfoundations/open_clip. 
-# -# Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, -# Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, -# John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, -# Ludwig Schmidt - - -from collections import OrderedDict -from typing import Callable, Optional - -import torch -from torch import nn -from open_clip.transformer import LayerNorm -from open_clip.transformer import ResidualAttentionBlock as _ResidualAttentionBlock -from open_clip.transformer import Transformer as _Transformer -from open_clip.transformer import VisionTransformer as _VisionTransformer -from open_clip.transformer import TextTransformer as _TextTransformer - -# Use flash attention -try: - from clip_server.model.flash_attention import MultiheadAttention - - FLASH_ATTENTION_AVAILABLE = True -except: - FLASH_ATTENTION_AVAILABLE = False - - -class ResidualAttentionBlock(_ResidualAttentionBlock): - def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - # scale_attn: bool = False, - # scale_fc: bool = False, - dtype=torch.float32, - ): - super().__init__( - d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer - ) - head_dim = d_model // n_head - flash_attention = head_dim % 8 == 0 and head_dim <= 128 - - self.attn = ( - MultiheadAttention(d_model, n_head) - if FLASH_ATTENTION_AVAILABLE - and torch.cuda.is_available() - and dtype in (torch.float16, torch.bfloat16) - and flash_attention - else nn.MultiheadAttention(d_model, n_head) - ) - - -class Transformer(_Transformer): - def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - dtype=torch.float32, - ): - super().__init__( - width, layers, heads, mlp_ratio, ls_init_value, act_layer, norm_layer - ) - - self.resblocks = nn.ModuleList( - [ - ResidualAttentionBlock( - width, - heads, - mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - dtype=dtype, - ) - for _ in range(layers) - ] - ) - - -class VisionTransformer(_VisionTransformer): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - ls_init_value: float = None, - output_dim: int = 512, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - dtype=torch.float32, - ): - super().__init__( - image_size, - patch_size, - width, - layers, - heads, - mlp_ratio, - ls_init_value, - output_dim, - act_layer, - norm_layer, - ) - self.transformer = Transformer( - width, - layers, - heads, - mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - dtype=dtype, - ) - - -class TextTransformer(_TextTransformer): - def __init__( - self, - context_length: int = 77, - vocab_size: int = 49408, - width: int = 512, - heads: int = 8, - layers: int = 12, - ls_init_value: float = None, - output_dim: int = 512, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - dtype=torch.float32, - ): - super().__init__( - context_length, - vocab_size, - width, - heads, - layers, - ls_init_value, - output_dim, - act_layer, - norm_layer, - ) - - self.transformer = Transformer( - width=width, - layers=layers, - heads=heads, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - dtype=dtype, - ) - self.init_parameters() +# # Originally from 
https://github.com/mlfoundations/open_clip. +# # +# # Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, +# # Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, +# # John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, +# # Ludwig Schmidt +# +# +# from collections import OrderedDict +# from typing import Callable, Optional +# +# import torch +# from torch import nn +# from open_clip.transformer import LayerNorm +# from open_clip.transformer import ResidualAttentionBlock as _ResidualAttentionBlock +# from open_clip.transformer import Transformer as _Transformer +# from open_clip.transformer import VisionTransformer as _VisionTransformer +# from open_clip.transformer import TextTransformer as _TextTransformer +# +# # Use flash attention +# try: +# from clip_server.model.flash_attention import MultiheadAttention +# +# FLASH_ATTENTION_AVAILABLE = True +# except: +# FLASH_ATTENTION_AVAILABLE = False +# +# +# class ResidualAttentionBlock(_ResidualAttentionBlock): +# def __init__( +# self, +# d_model: int, +# n_head: int, +# mlp_ratio: float = 4.0, +# ls_init_value: float = None, +# act_layer: Callable = nn.GELU, +# norm_layer: Callable = LayerNorm, +# # scale_attn: bool = False, +# # scale_fc: bool = False, +# dtype=torch.float32, +# ): +# super().__init__( +# d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer +# ) +# head_dim = d_model // n_head +# flash_attention = head_dim % 8 == 0 and head_dim <= 128 +# +# self.attn = ( +# MultiheadAttention(d_model, n_head) +# if FLASH_ATTENTION_AVAILABLE +# and torch.cuda.is_available() +# and dtype in (torch.float16, torch.bfloat16) +# and flash_attention +# else nn.MultiheadAttention(d_model, n_head) +# ) +# +# +# class Transformer(_Transformer): +# def __init__( +# self, +# width: int, +# layers: int, +# heads: int, +# mlp_ratio: float = 4.0, +# ls_init_value: float = None, +# act_layer: Callable = nn.GELU, +# norm_layer: Callable = LayerNorm, +# dtype=torch.float32, +# ): +# super().__init__( +# width, layers, heads, mlp_ratio, ls_init_value, act_layer, norm_layer +# ) +# +# self.resblocks = nn.ModuleList( +# [ +# ResidualAttentionBlock( +# width, +# heads, +# mlp_ratio, +# ls_init_value=ls_init_value, +# act_layer=act_layer, +# norm_layer=norm_layer, +# dtype=dtype, +# ) +# for _ in range(layers) +# ] +# ) +# +# +# class VisionTransformer(_VisionTransformer): +# def __init__( +# self, +# image_size: int, +# patch_size: int, +# width: int, +# layers: int, +# heads: int, +# mlp_ratio: float, +# ls_init_value: float = None, +# output_dim: int = 512, +# act_layer: Callable = nn.GELU, +# norm_layer: Callable = LayerNorm, +# dtype=torch.float32, +# ): +# super().__init__( +# image_size, +# patch_size, +# width, +# layers, +# heads, +# mlp_ratio, +# ls_init_value, +# output_dim, +# act_layer, +# norm_layer, +# ) +# self.transformer = Transformer( +# width, +# layers, +# heads, +# mlp_ratio, +# ls_init_value=ls_init_value, +# act_layer=act_layer, +# norm_layer=norm_layer, +# dtype=dtype, +# ) +# +# +# class TextTransformer(_TextTransformer): +# def __init__( +# self, +# context_length: int = 77, +# vocab_size: int = 49408, +# width: int = 512, +# heads: int = 8, +# layers: int = 12, +# ls_init_value: float = None, +# output_dim: int = 512, +# act_layer: Callable = nn.GELU, +# norm_layer: Callable = LayerNorm, +# dtype=torch.float32, +# ): +# super().__init__( +# context_length, +# vocab_size, +# width, +# heads, +# layers, +# ls_init_value, +# output_dim, +# act_layer, +# norm_layer, +# ) +# +# self.transformer = 
Transformer( +# width=width, +# layers=layers, +# heads=heads, +# ls_init_value=ls_init_value, +# act_layer=act_layer, +# norm_layer=norm_layer, +# dtype=dtype, +# ) +# self.init_parameters() From 06ad96af1e6265bd6cd8cc94015744cd5f5a4efb Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 15:43:38 +0800 Subject: [PATCH 24/33] fix: refactor --- server/clip_server/model/model.py | 11 +- server/clip_server/model/transformer.py | 164 ------------------------ 2 files changed, 2 insertions(+), 173 deletions(-) delete mode 100644 server/clip_server/model/transformer.py diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 2b757d1ed..b2052cea2 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -9,17 +9,12 @@ """ import warnings +import torch +from torch import nn from typing import Callable from dataclasses import dataclass from typing import Tuple, Union, Optional from copy import deepcopy - -import torch -from torch import nn - -# import torch.nn.functional as F - -# from clip_server.model.transformer import TextTransformer, VisionTransformer from open_clip.transformer import QuickGELU, LayerNorm, LayerNormFp32, Attention from open_clip.timm_model import TimmModel from open_clip.factory import _MODEL_CONFIGS @@ -57,8 +52,6 @@ def __init__( ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - # scale_attn: bool = False, - # scale_fc: bool = False, dtype=torch.float32, ): super().__init__( diff --git a/server/clip_server/model/transformer.py b/server/clip_server/model/transformer.py deleted file mode 100644 index 1ae166484..000000000 --- a/server/clip_server/model/transformer.py +++ /dev/null @@ -1,164 +0,0 @@ -# # Originally from https://github.com/mlfoundations/open_clip. 
-# # -# # Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, -# # Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, -# # John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, -# # Ludwig Schmidt -# -# -# from collections import OrderedDict -# from typing import Callable, Optional -# -# import torch -# from torch import nn -# from open_clip.transformer import LayerNorm -# from open_clip.transformer import ResidualAttentionBlock as _ResidualAttentionBlock -# from open_clip.transformer import Transformer as _Transformer -# from open_clip.transformer import VisionTransformer as _VisionTransformer -# from open_clip.transformer import TextTransformer as _TextTransformer -# -# # Use flash attention -# try: -# from clip_server.model.flash_attention import MultiheadAttention -# -# FLASH_ATTENTION_AVAILABLE = True -# except: -# FLASH_ATTENTION_AVAILABLE = False -# -# -# class ResidualAttentionBlock(_ResidualAttentionBlock): -# def __init__( -# self, -# d_model: int, -# n_head: int, -# mlp_ratio: float = 4.0, -# ls_init_value: float = None, -# act_layer: Callable = nn.GELU, -# norm_layer: Callable = LayerNorm, -# # scale_attn: bool = False, -# # scale_fc: bool = False, -# dtype=torch.float32, -# ): -# super().__init__( -# d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer -# ) -# head_dim = d_model // n_head -# flash_attention = head_dim % 8 == 0 and head_dim <= 128 -# -# self.attn = ( -# MultiheadAttention(d_model, n_head) -# if FLASH_ATTENTION_AVAILABLE -# and torch.cuda.is_available() -# and dtype in (torch.float16, torch.bfloat16) -# and flash_attention -# else nn.MultiheadAttention(d_model, n_head) -# ) -# -# -# class Transformer(_Transformer): -# def __init__( -# self, -# width: int, -# layers: int, -# heads: int, -# mlp_ratio: float = 4.0, -# ls_init_value: float = None, -# act_layer: Callable = nn.GELU, -# norm_layer: Callable = LayerNorm, -# dtype=torch.float32, -# ): -# super().__init__( -# width, layers, heads, mlp_ratio, ls_init_value, act_layer, norm_layer -# ) -# -# self.resblocks = nn.ModuleList( -# [ -# ResidualAttentionBlock( -# width, -# heads, -# mlp_ratio, -# ls_init_value=ls_init_value, -# act_layer=act_layer, -# norm_layer=norm_layer, -# dtype=dtype, -# ) -# for _ in range(layers) -# ] -# ) -# -# -# class VisionTransformer(_VisionTransformer): -# def __init__( -# self, -# image_size: int, -# patch_size: int, -# width: int, -# layers: int, -# heads: int, -# mlp_ratio: float, -# ls_init_value: float = None, -# output_dim: int = 512, -# act_layer: Callable = nn.GELU, -# norm_layer: Callable = LayerNorm, -# dtype=torch.float32, -# ): -# super().__init__( -# image_size, -# patch_size, -# width, -# layers, -# heads, -# mlp_ratio, -# ls_init_value, -# output_dim, -# act_layer, -# norm_layer, -# ) -# self.transformer = Transformer( -# width, -# layers, -# heads, -# mlp_ratio, -# ls_init_value=ls_init_value, -# act_layer=act_layer, -# norm_layer=norm_layer, -# dtype=dtype, -# ) -# -# -# class TextTransformer(_TextTransformer): -# def __init__( -# self, -# context_length: int = 77, -# vocab_size: int = 49408, -# width: int = 512, -# heads: int = 8, -# layers: int = 12, -# ls_init_value: float = None, -# output_dim: int = 512, -# act_layer: Callable = nn.GELU, -# norm_layer: Callable = LayerNorm, -# dtype=torch.float32, -# ): -# super().__init__( -# context_length, -# vocab_size, -# width, -# heads, -# layers, -# ls_init_value, -# output_dim, -# act_layer, -# norm_layer, -# ) -# -# self.transformer = Transformer( -# width=width, -# 
layers=layers, -# heads=heads, -# ls_init_value=ls_init_value, -# act_layer=act_layer, -# norm_layer=norm_layer, -# dtype=dtype, -# ) -# self.init_parameters() From 8ef2090b3cb7ba24f64e0ee5d5f319f59ed757d0 Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 16:01:56 +0800 Subject: [PATCH 25/33] fix: type hint --- server/clip_server/model/model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index b2052cea2..3fda1ef29 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -52,7 +52,7 @@ def __init__( ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - dtype=torch.float32, + dtype: torch.dtype = torch.float32, ): super().__init__( d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer @@ -80,7 +80,7 @@ def __init__( ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - dtype=torch.float32, + dtype: torch.dtype = torch.float32, ): super().__init__( width, layers, heads, mlp_ratio, ls_init_value, act_layer, norm_layer @@ -115,7 +115,7 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - dtype=torch.float32, + dtype: torch.dtype = torch.float32, ): super().__init__( image_size, @@ -153,7 +153,7 @@ def __init__( output_dim: int = 512, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, - dtype=torch.float32, + dtype: torch.dtype = torch.float32, ): super().__init__( context_length, @@ -391,7 +391,7 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, - dtype=torch.float16, + dtype: torch.dtype = torch.float16, ): vit = "visual.proj" in state_dict From 8a36b8aee08e48eab337572e265fc089f7eb01d1 Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 18:57:13 +0800 Subject: [PATCH 26/33] fix: address comments --- server/clip_server/model/flash_attention.py | 6 - server/clip_server/model/model.py | 153 ++++-------------- server/clip_server/model/openclip_model.py | 4 +- server/clip_server/model/pretrained_models.py | 11 +- tests/test_model.py | 2 +- 5 files changed, 46 insertions(+), 130 deletions(-) diff --git a/server/clip_server/model/flash_attention.py b/server/clip_server/model/flash_attention.py index a5ed6c13e..bf244ac5f 100644 --- a/server/clip_server/model/flash_attention.py +++ b/server/clip_server/model/flash_attention.py @@ -57,12 +57,6 @@ def attention( key_padding_mask: a bool tensor of shape (B, S) """ - assert not need_weights, "not allowed to return weights." - assert q.dtype in [ - torch.float16, - torch.bfloat16, - ], f"flash attention only support torch.float16 or torch.bfloat16 but got {q.dtype}." - assert q.is_cuda, "flash attention only support cuda." 
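# The call-time asserts removed here (fp16/bf16 only, CUDA only) are now covered
# before this function is ever reached: the loaders default to fp32 on CPU and
# fp16 on GPU, and only fp16/bf16 CUDA blocks instantiate the flash-attention
# path in the first place. A minimal sketch of that device-based default, using
# the illustrative helper name resolve_default_dtype (not part of the patch):
import torch

def resolve_default_dtype(device, dtype=None):
    # An explicit dtype wins; otherwise mirror the default used by
    # load_openai_model / load_openclip_model in this series.
    if dtype is not None:
        return dtype
    return torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16

print(resolve_default_dtype('cpu'))                   # torch.float32
print(resolve_default_dtype('cuda'))                  # torch.float16
print(resolve_default_dtype('cuda', torch.bfloat16))  # explicit override wins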
if cu_seqlens is None: max_s = seqlen diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 3fda1ef29..b43b9a66b 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -11,7 +11,6 @@ import warnings import torch from torch import nn -from typing import Callable from dataclasses import dataclass from typing import Tuple, Union, Optional from copy import deepcopy @@ -45,18 +44,9 @@ def forward(self, x): class ResidualAttentionBlock(_ResidualAttentionBlock): def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - dtype: torch.dtype = torch.float32, + self, d_model: int, n_head: int, dtype: torch.dtype = torch.float32, **kwargs ): - super().__init__( - d_model, n_head, mlp_ratio, ls_init_value, act_layer, norm_layer - ) + super().__init__(d_model, n_head, **kwargs) head_dim = d_model // n_head flash_attention = head_dim % 8 == 0 and head_dim <= 128 @@ -71,111 +61,23 @@ def __init__( class Transformer(_Transformer): - def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - dtype: torch.dtype = torch.float32, - ): - super().__init__( - width, layers, heads, mlp_ratio, ls_init_value, act_layer, norm_layer - ) - + def __init__(self, layers: int, dtype: torch.dtype = torch.float32, **kwargs): + super().__init__(layers=layers, **kwargs) self.resblocks = nn.ModuleList( - [ - ResidualAttentionBlock( - width, - heads, - mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - dtype=dtype, - ) - for _ in range(layers) - ] + [ResidualAttentionBlock(dtype=dtype, **kwargs) for _ in range(layers)] ) class VisionTransformer(_VisionTransformer): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - ls_init_value: float = None, - output_dim: int = 512, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - dtype: torch.dtype = torch.float32, - ): - super().__init__( - image_size, - patch_size, - width, - layers, - heads, - mlp_ratio, - ls_init_value, - output_dim, - act_layer, - norm_layer, - ) - self.transformer = Transformer( - width, - layers, - heads, - mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - dtype=dtype, - ) + def __init__(self, dtype: torch.dtype = torch.float32, **kwargs): + super().__init__(**kwargs) + self.transformer = Transformer(dtype=dtype, **kwargs) class TextTransformer(_TextTransformer): - def __init__( - self, - context_length: int = 77, - vocab_size: int = 49408, - width: int = 512, - heads: int = 8, - layers: int = 12, - ls_init_value: float = None, - output_dim: int = 512, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - dtype: torch.dtype = torch.float32, - ): - super().__init__( - context_length, - vocab_size, - width, - heads, - layers, - ls_init_value, - output_dim, - act_layer, - norm_layer, - ) - - self.transformer = Transformer( - width=width, - layers=layers, - heads=heads, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - dtype=dtype, - ) + def __init__(self, dtype: torch.dtype = torch.float32, **kwargs): + super().__init__(**kwargs) + self.transformer = Transformer(dtype=dtype, **kwargs) self.init_parameters() @@ -233,7 +135,7 @@ 
def _build_vision_tower( if vision_cfg.timm_model_name: visual = TimmModel( - vision_cfg.timm_model_name, + model_name=vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, @@ -320,8 +222,15 @@ def __init__( dtype: Optional[torch.dtype] = torch.float32, ): super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype) - self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, dtype) - self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, dtype) + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + dtype=dtype, + ) + self.text = _build_text_tower( + embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, dtype=dtype + ) class CLIP(_CLIP): @@ -333,9 +242,15 @@ def __init__( quick_gelu: bool = False, dtype: Optional[torch.dtype] = torch.float32, ): - super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype) - self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, dtype) - text = _build_text_tower(embed_dim, text_cfg, quick_gelu, dtype) + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + dtype=dtype, + ) + text = _build_text_tower( + embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, dtype=dtype + ) self.transformer = text.transformer self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding @@ -390,7 +305,7 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'): def build_model_from_openai_state_dict( state_dict: dict, - quick_gelu=True, + quick_gelu: bool = False, dtype: torch.dtype = torch.float16, ): vit = "visual.proj" in state_dict @@ -459,7 +374,7 @@ def build_model_from_openai_state_dict( layers=transformer_layers, ) model = CLIP( - embed_dim, + embed_dim=embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg, quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU @@ -479,7 +394,7 @@ def load_openai_model( model_path: str, device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu', dtype: Optional[Union[str, torch.dtype]] = None, - jit=True, + jit: bool = True, ): """Load a CLIP model @@ -608,10 +523,10 @@ def load_openclip_model( model_path: str, device: Union[str, torch.device] = 'cpu', jit: bool = False, - dtype: Optional[Union[str, torch.dtype]] = None, force_quick_gelu: bool = False, force_custom_text: bool = False, pretrained_image: bool = False, + dtype: Optional[Union[str, torch.dtype]] = None, ): if dtype is None: dtype = ( diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py index 5d15be9f2..2091be33d 100644 --- a/server/clip_server/model/openclip_model.py +++ b/server/clip_server/model/openclip_model.py @@ -36,11 +36,11 @@ def __init__( if pretrained == 'openai': self._model = load_openai_model( - model_path, device=device, jit=jit, dtype=dtype + model_path=model_path, device=device, jit=jit, dtype=dtype ) else: self._model = load_openclip_model( - self._model_name, + model_name=self._model_name, model_path=model_path, device=device, jit=jit, diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py index 272f9bdaa..fb11b13de 100644 --- a/server/clip_server/model/pretrained_models.py +++ b/server/clip_server/model/pretrained_models.py @@ -71,13 +71,17 @@ '3bf99353f6f1829faac0bb155be4382a', ), 'roberta-ViT-B-32::laion2b-s12b-b32k': ( - 
'ViT-B-32-laion2b-s12b-b32k.bin', + 'roberta-ViT-B-32-laion2b-s12b-b32k.bin', '76d4c9d13774cc15fa0e2b1b94a8402c', ), 'xlm-roberta-base-ViT-B-32::laion5b-s13b-b90k': ( - 'ViT-B-32::laion5b-s13b-b90k.bin', + 'xlm-roberta-base-ViT-B-32-laion5b-s13b-b90k.bin', 'f68abc07ef349720f1f880180803142d', ), + 'xlm-roberta-large-ViT-H-14::frozen_laion5b_s13b_b90k': ( + 'xlm-roberta-large-ViT-H-14-frozen_laion5b_s13b_b90k.bin', + 'b49991239a419d704fdba59c42d5536d', + ), # older version name format 'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), 'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), @@ -104,12 +108,15 @@ 'RN50x16': 384, 'RN50x64': 448, 'ViT-B-32': 224, + 'roberta-ViT-B-32': 224, + 'xlm-roberta-base-ViT-B-32': 224, 'ViT-B-16': 224, 'Vit-B-16Plus': 240, 'ViT-B-16-plus-240': 240, 'ViT-L-14': 224, 'ViT-L-14-336': 336, 'ViT-H-14': 224, + 'xlm-roberta-large-ViT-H-14': 224, 'ViT-g-14': 224, } diff --git a/tests/test_model.py b/tests/test_model.py index 79ee3d7fb..e6f8fa7c9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -10,7 +10,7 @@ [ ('ViT-L/14@336px', OpenCLIPModel), ('RN50::openai', OpenCLIPModel), - # ('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel), + ('xlm-roberta-large-ViT-H-14::frozen_laion5b_s13b_b90k', OpenCLIPModel), ('M-CLIP/LABSE-Vit-L-14', MultilingualCLIPModel), ], ) From cf9595df6d3d2707dfc5861aa0e062227243cfc5 Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Mon, 21 Nov 2022 20:24:23 +0800 Subject: [PATCH 27/33] fix: visiontransformer --- server/clip_server/model/model.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index b43b9a66b..95c6701fe 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -69,8 +69,15 @@ def __init__(self, layers: int, dtype: torch.dtype = torch.float32, **kwargs): class VisionTransformer(_VisionTransformer): - def __init__(self, dtype: torch.dtype = torch.float32, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + image_size: int, + patch_size: int, + output_dim: int, + dtype: torch.dtype = torch.float32, + **kwargs, + ): + super().__init__(image_size, patch_size, output_dim=output_dim, **kwargs) self.transformer = Transformer(dtype=dtype, **kwargs) From e43990bdc9705f7c34e8f99a881da2bdbeb779ff Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Mon, 21 Nov 2022 20:35:32 +0800 Subject: [PATCH 28/33] fix: d_model and n_head --- server/clip_server/model/model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 95c6701fe..3e8171bdb 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -44,19 +44,19 @@ def forward(self, x): class ResidualAttentionBlock(_ResidualAttentionBlock): def __init__( - self, d_model: int, n_head: int, dtype: torch.dtype = torch.float32, **kwargs + self, width: int, heads: int, dtype: torch.dtype = torch.float32, **kwargs ): - super().__init__(d_model, n_head, **kwargs) - head_dim = d_model // n_head + super().__init__(width, heads, **kwargs) + head_dim = width // heads flash_attention = head_dim % 8 == 0 and head_dim <= 128 self.attn = ( - MultiheadAttention(d_model, n_head) + MultiheadAttention(width, heads) if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available() and dtype in (torch.float16, torch.bfloat16) and flash_attention - else nn.MultiheadAttention(d_model, n_head) + else nn.MultiheadAttention(width, heads) 
) From 1bf83adb833e835f86f6a9df429a41b9e6ca3aac Mon Sep 17 00:00:00 2001 From: YangXiuyu Date: Mon, 21 Nov 2022 21:29:10 +0800 Subject: [PATCH 29/33] fix: texttransformer --- server/clip_server/model/model.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 3e8171bdb..29cd1c493 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -82,8 +82,15 @@ def __init__( class TextTransformer(_TextTransformer): - def __init__(self, dtype: torch.dtype = torch.float32, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + context_length: int, + vocab_size: int, + output_dim: int, + dtype: torch.dtype = torch.float32, + **kwargs, + ): + super().__init__(context_length, vocab_size, output_dim=output_dim, **kwargs) self.transformer = Transformer(dtype=dtype, **kwargs) self.init_parameters() @@ -249,22 +256,7 @@ def __init__( quick_gelu: bool = False, dtype: Optional[torch.dtype] = torch.float32, ): - self.visual = _build_vision_tower( - embed_dim=embed_dim, - vision_cfg=vision_cfg, - quick_gelu=quick_gelu, - dtype=dtype, - ) - text = _build_text_tower( - embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, dtype=dtype - ) - self.transformer = text.transformer - self.vocab_size = text.vocab_size - self.token_embedding = text.token_embedding - self.positional_embedding = text.positional_embedding - self.ln_final = text.ln_final - self.text_projection = text.text_projection - self.register_buffer('attn_mask', text.attn_mask, persistent=False) + super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype) def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): From f918d65779b314ce97bf4fb0b9943c3824f62906 Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 22:25:42 +0800 Subject: [PATCH 30/33] fix: change model to fix oom gh action --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index e6f8fa7c9..053a1d2b0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -10,7 +10,7 @@ [ ('ViT-L/14@336px', OpenCLIPModel), ('RN50::openai', OpenCLIPModel), - ('xlm-roberta-large-ViT-H-14::frozen_laion5b_s13b_b90k', OpenCLIPModel), + ('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel), ('M-CLIP/LABSE-Vit-L-14', MultilingualCLIPModel), ], ) From d910ee8b95885894fc71070c22710944804453be Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 23:53:48 +0800 Subject: [PATCH 31/33] fix: class init and param name --- server/clip_server/model/model.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 29cd1c493..c3f8f2e0b 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -10,6 +10,7 @@ import warnings import torch +import numpy as np from torch import nn from dataclasses import dataclass from typing import Tuple, Union, Optional @@ -181,6 +182,7 @@ def _build_vision_tower( layers=vision_cfg.layers, heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, @@ -218,6 +220,7 @@ def _build_text_tower( width=text_cfg.width, heads=text_cfg.heads, layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, @@ -256,7 
+259,17 @@ def __init__( quick_gelu: bool = False, dtype: Optional[torch.dtype] = torch.float32, ): - super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype) + nn.Module.__init__(self) + self.visual = _build_vision_tower(embed_dim=embed_dim, vision_cfg=vision_cfg, quick_gelu=quick_gelu, dtype=dtype) + text = _build_text_tower(embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, dtype=dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): From 423f7887aaa4bec3b9dd03282a13eec7b0bf903d Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Mon, 21 Nov 2022 23:59:41 +0800 Subject: [PATCH 32/33] fix: black --- server/clip_server/model/model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index c3f8f2e0b..e4cfe3c3a 100644 --- a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -260,8 +260,16 @@ def __init__( dtype: Optional[torch.dtype] = torch.float32, ): nn.Module.__init__(self) - self.visual = _build_vision_tower(embed_dim=embed_dim, vision_cfg=vision_cfg, quick_gelu=quick_gelu, dtype=dtype) - text = _build_text_tower(embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, dtype=dtype) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + dtype=dtype, + ) + text = _build_text_tower( + embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, dtype=dtype + ) self.transformer = text.transformer self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding From 04c130aae40364e856b09abd4e8172b13b4c333e Mon Sep 17 00:00:00 2001 From: ZiniuYu Date: Tue, 22 Nov 2022 00:23:53 +0800 Subject: [PATCH 33/33] chore: bump open-clip version --- server/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/setup.py b/server/setup.py index 56fceb645..32f9ca25c 100644 --- a/server/setup.py +++ b/server/setup.py @@ -48,7 +48,7 @@ 'torchvision', 'jina>=3.8.0', 'prometheus-client', - 'open_clip_torch>=2.5.0', + 'open_clip_torch>=2.7.0', ], extras_require={ 'onnx': [
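The closing patch raises the open_clip_torch floor from 2.5.0 to 2.7.0, which is what the rest of the series leans on: the custom transformer, ResNet and CLIP classes are now thin subclasses of open_clip's own modules. Below is a quick, self-contained way to confirm that a given environment exposes the symbols these patches import; it is a local sanity-check sketch, not part of the patch series.

import importlib

# Symbols the refactored model.py imports from open_clip (see its import block).
REQUIRED = {
    'open_clip.transformer': ['LayerNorm', 'LayerNormFp32', 'QuickGELU',
                              'ResidualAttentionBlock', 'Transformer',
                              'VisionTransformer', 'TextTransformer'],
    'open_clip.model': ['CLIP', 'CustomTextCLIP'],
    'open_clip.modified_resnet': ['ModifiedResNet'],
    'open_clip.hf_model': ['PreTrainedTextEncoder'],
    'open_clip.factory': ['_MODEL_CONFIGS'],
}

for module_name, names in REQUIRED.items():
    try:
        module = importlib.import_module(module_name)
        missing = [n for n in names if not hasattr(module, n)]
        print(module_name, 'ok' if not missing else f'missing {missing}')
    except ImportError as exc:
        print(module_name, 'not importable:', exc)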