Skip to content

Commit

Permalink
feat: add clip_hg
Browse files: browse the repository at this point in the history
  • Loading branch information
ZiniuYu committed Jun 1, 2022
1 parent 626cbc3 commit b8dc095
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions server/clip_server/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from torch import nn




class Bottleneck(nn.Module):
expansion = 4

Expand Down Expand Up @@ -76,7 +74,7 @@ def __init__(
):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
Expand Down Expand Up @@ -269,7 +267,7 @@ def __init__(
bias=False,
)

scale = width ** -0.5
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(
scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
Expand Down Expand Up @@ -375,7 +373,7 @@ def initialize_parameters(self):

if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
std = self.visual.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
Expand All @@ -391,10 +389,10 @@ def initialize_parameters(self):
if name.endswith('bn3.weight'):
nn.init.zeros_(param)

proj_std = (self.transformer.width ** -0.5) * (
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers) ** -0.5
)
attn_std = self.transformer.width ** -0.5
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
Expand All @@ -403,7 +401,7 @@ def initialize_parameters(self):
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
Expand Down Expand Up @@ -516,7 +514,7 @@ def build_model(state_dict: dict):
)
vision_patch_size = None
assert (
output_width ** 2 + 1
output_width**2 + 1
== state_dict['visual.attnpool.positional_embedding'].shape[0]
)
image_resolution = output_width * 32
Expand Down

0 comments on commit b8dc095

Please sign in to comment.