
Commit

fix: texttransformer
OrangeSodahub committed Nov 21, 2022
1 parent e43990b commit 1bf83ad
Showing 1 changed file with 10 additions and 18 deletions.
server/clip_server/model/model.py: 10 additions, 18 deletions

@@ -82,8 +82,15 @@ def __init__(
 
 
 class TextTransformer(_TextTransformer):
-    def __init__(self, dtype: torch.dtype = torch.float32, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        context_length: int,
+        vocab_size: int,
+        output_dim: int,
+        dtype: torch.dtype = torch.float32,
+        **kwargs,
+    ):
+        super().__init__(context_length, vocab_size, output_dim=output_dim, **kwargs)
         self.transformer = Transformer(dtype=dtype, **kwargs)
         self.init_parameters()
 
@@ -249,22 +256,7 @@ def __init__(
         quick_gelu: bool = False,
         dtype: Optional[torch.dtype] = torch.float32,
     ):
-        self.visual = _build_vision_tower(
-            embed_dim=embed_dim,
-            vision_cfg=vision_cfg,
-            quick_gelu=quick_gelu,
-            dtype=dtype,
-        )
-        text = _build_text_tower(
-            embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, dtype=dtype
-        )
-        self.transformer = text.transformer
-        self.vocab_size = text.vocab_size
-        self.token_embedding = text.token_embedding
-        self.positional_embedding = text.positional_embedding
-        self.ln_final = text.ln_final
-        self.text_projection = text.text_projection
-        self.register_buffer('attn_mask', text.attn_mask, persistent=False)
+        super().__init__(embed_dim, vision_cfg, text_cfg, quick_gelu, dtype)
 
 
 def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
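Both hunks replace hand-rolled initialization with calls to the parent constructors: TextTransformer now passes context_length, vocab_size, and output_dim through to _TextTransformer before swapping in a dtype-aware Transformer, and CLIP.__init__ delegates vision- and text-tower construction to its superclass instead of duplicating the _build_vision_tower / _build_text_tower wiring. A minimal usage sketch of the patched class follows; it assumes open_clip-style keyword arguments (width, heads, layers) are forwarded via **kwargs, and the concrete values are illustrative, not taken from the commit.

    import torch

    from clip_server.model.model import TextTransformer

    # Hypothetical construction of the patched TextTransformer; every
    # value below is an assumption for illustration, not from the commit.
    text_tower = TextTransformer(
        context_length=77,    # token context length commonly used by CLIP
        vocab_size=49408,     # BPE vocabulary size commonly used by CLIP
        output_dim=512,       # dimensionality of the projected text embedding
        dtype=torch.float16,  # inner Transformer built in half precision
        width=512,            # assumed **kwargs consumed by _TextTransformer
        heads=8,              # and by the dtype-aware Transformer
        layers=12,
    )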
