Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support B/32, L/14, H/14, and g/14 trained on LAION-2B #825

Merged
merged 6 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/user-guides/server.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Please also note that **different models give different sizes of output dimensio
| ViT-B-32::laion2b_e16 | ✅ | ✅ | ✅ | 512 | 577 | 2.93 | 1.40 |
| ViT-B-32::laion400m_e31 | ✅ | ✅ | ✅ | 512 | 577 | 2.93 | 1.40 |
| ViT-B-32::laion400m_e32 | ✅ | ✅ | ✅ | 512 | 577 | 2.94 | 1.40 |
| ViT-B-32::laion2B-s34B-b79K | ✅ | ✅ | ❌ | 512 | 577 | 2.94 | 1.40 |
| ViT-B-16::openai | ✅ | ✅ | ✅ | 512 | 335 | 3.20 | 1.44 |
| ViT-B-16::laion400m_e31 | ✅ | ✅ | ✅ | 512 | 571 | 2.93 | 1.44 |
| ViT-B-16::laion400m_e32 | ✅ | ✅ | ✅ | 512 | 571 | 2.94 | 1.44 |
Expand All @@ -87,7 +88,10 @@ Please also note that **different models give different sizes of output dimensio
| ViT-L-14::openai | ✅ | ✅ | ❌ | 768 | 890 | 3.66 | 2.04 |
| ViT-L-14::laion400m_e31 | ✅ | ✅ | ❌ | 768 | 1631 | 3.43 | 2.03 |
| ViT-L-14::laion400m_e32 | ✅ | ✅ | ❌ | 768 | 1631 | 3.42 | 2.03 |
| ViT-L-14::laion2B-s32B-b82K | ✅ | ✅ | ❌ | 768 | 1631 | 3.43 | 2.03 |
| ViT-L-14-336::openai | ✅ | ✅ | ❌ | 768 | 891 | 3.74 | 2.23 |
| ViT-H-14::laion2B-s32B-b79K | ✅ | 🚧 | ❌ | 1024 | 3762 | 4.45 | 3.26 |
| ViT-g-14::laion2B-s12B-b42K | ✅ | 🚧 | ❌ | 1024 | 5214 | 5.16 | 4.00 |
| M-CLIP/XLM-Roberta-Large-Vit-B-32 | ✅ | 🚧 | 🚧 | 512 | 4284 | 5.37 | 1.68 |
| M-CLIP/XLM-Roberta-Large-Vit-L-14 | ✅ | 🚧 | ❌ | 768 | 4293 | 4.30 | 4.97 |
| M-CLIP/XLM-Roberta-Large-Vit-B-16Plus | ✅ | 🚧 | 🚧 | 640 | 4293 | 4.30 | 4.13 |
Expand Down Expand Up @@ -192,7 +196,6 @@ Basically, each YAML file defines a [Jina Flow](https://docs.jina.ai/fundamental
Looking at the YAML file again, we can put it into three subsections as below:



````{tab} CLIP model config

```{code-block} yaml
Expand Down
8 changes: 8 additions & 0 deletions server/clip_server/model/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
('ViT-B-32-laion400m_e32/textual.onnx', '93284915937ba42a2b52ae8d3e5283a0'),
('ViT-B-32-laion400m_e32/visual.onnx', 'db220821a31fe9795fd8c2ba419078c5'),
),
'ViT-B-32::laion2B-s34B-b79K': (
('ViT-B-32-laion2B-s34B-b79K/textual.onnx', '84af5ae53da56464c76e67fe50fddbe9'),
('ViT-B-32-laion2B-s34B-b79K/visual.onnx', 'a2d4cbd1cf2632cd09ffce9b40bfd8bd'),
),
'ViT-B-16::openai': (
('ViT-B-16/textual.onnx', '6f0976629a446f95c0c8767658f12ebe'),
('ViT-B-16/visual.onnx', 'd5c03bfeef1abbd9bede54a8f6e1eaad'),
Expand Down Expand Up @@ -105,6 +109,10 @@
('ViT-L-14-laion400m_e32/textual.onnx', '8ba5b76ba71992923470c0261b10a67c'),
('ViT-L-14-laion400m_e32/visual.onnx', '49db3ba92bd816001e932530ad92d76c'),
),
'ViT-L-14::laion2B-s32B-b82K': (
('ViT-L-14-laion2B-s32B-b82K/textual.onnx', 'da36a6cbed4f56abf576fdea8b6fe2ee'),
('ViT-L-14-laion2B-s32B-b82K/visual.onnx', '1e337a190abba6a8650237dfae4740b7'),
),
'ViT-L-14-336::openai': (
('ViT-L-14@336px/textual.onnx', '78fab479f136403eed0db46f3e9e7ed2'),
('ViT-L-14@336px/visual.onnx', 'f3b1f5d55ca08d43d749e11f7e4ba27e'),
Expand Down
37 changes: 22 additions & 15 deletions server/clip_server/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ def __init__(self, inplanes, planes, stride=1):
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.act1 = nn.ReLU(inplace=True)

self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.act2 = nn.ReLU(inplace=True)

self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu3 = nn.ReLU(inplace=True)
self.act3 = nn.ReLU(inplace=True)

self.downsample = None
self.stride = stride
Expand Down Expand Up @@ -88,16 +88,16 @@ def __init__(self, inplanes, planes, stride=1):
def forward(self, x: torch.Tensor):
identity = x

out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu3(out)
out = self.act3(out)
return out


Expand Down Expand Up @@ -166,15 +166,15 @@ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)

# residual layers
Expand Down Expand Up @@ -226,9 +226,9 @@ def set_grad_checkpointing(self, enable=True):
pass

def stem(self, x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x

Expand Down Expand Up @@ -273,30 +273,37 @@ def __init__(
n_head: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
scale_cosine_attn: bool = False,
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
):
super().__init__()

self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity()

self.ln_2 = LayerNorm(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, mlp_width)),
('ln', LayerNorm(mlp_width) if scale_fc else nn.Identity()),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)

def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
if attn_mask is not None:
attn_mask = attn_mask.to(dtype=x.dtype, device=x.device)
x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
x = x + self.ln_attn(self.attention(self.ln_1(x), attn_mask=attn_mask))
x = x + self.mlp(self.ln_2(x))
return x

Expand Down
20 changes: 19 additions & 1 deletion server/clip_server/model/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
'ViT-B-32-laion400m_e32.pt',
'359e0dba4a419f175599ee0c63a110d8',
),
'ViT-B-32::laion2B-s34B-b79K': (
'ViT-B-32-laion2B-s34B-b79K.bin',
'2fc036aea9cd7306f5ce7ce6abb8d0bf',
),
'ViT-B-16::openai': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'),
'ViT-B-16::laion400m_e31': (
'ViT-B-16-laion400m_e31.pt',
Expand All @@ -53,7 +57,19 @@
'ViT-L-14-laion400m_e32.pt',
'a76cde1bc744ca38c6036b920c847a89',
),
'ViT-L-14::laion2B-s32B-b82K': (
'ViT-L-14-laion2B-s32B-b82K.bin',
'4d2275fc7b2d7ee9db174f9b57ddecbd',
),
'ViT-L-14-336::openai': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'),
'ViT-H-14::laion2B-s32B-b79K': (
'ViT-H-14-laion2B-s32B-b79K.bin',
'2aa6c46521b165a0daeb8cdc6668c7d3',
),
'ViT-g-14::laion2B-s12B-b42K': (
'ViT-g-14-laion2B-s12B-b42K.bin',
'3bf99353f6f1829faac0bb155be4382a',
),
# older version name format
'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'),
'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'),
Expand Down Expand Up @@ -81,10 +97,12 @@
'RN50x64': 448,
'ViT-B-32': 224,
'ViT-B-16': 224,
'Vit-B-16Plus': 240,
'ViT-B-16-plus-240': 240,
'ViT-L-14': 224,
'ViT-L-14-336': 336,
'Vit-B-16Plus': 240,
'ViT-H-14': 224,
'ViT-g-14': 224,
}


Expand Down
1 change: 1 addition & 0 deletions server/clip_server/onnx-flow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ executors:
metas:
py_modules:
- clip_server.executors.clip_onnx
timeout_ready: 3000000
replicas: 1
1 change: 1 addition & 0 deletions server/clip_server/torch-flow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ executors:
metas:
py_modules:
- clip_server.executors.clip_torch
timeout_ready: 3000000
replicas: 1