diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py
index 28adc7296..77a55e22b 100644
--- a/server/clip_server/model/model.py
+++ b/server/clip_server/model/model.py
@@ -19,7 +19,7 @@
 from open_clip.transformer import QuickGELU, LayerNorm, LayerNormFp32, Attention
 from open_clip.timm_model import TimmModel
 from open_clip.factory import _MODEL_CONFIGS
-from open_clip.hf_model import PreTrainedTextEncoder
+from open_clip.hf_model import HFTextEncoder
 from open_clip.transformer import ResidualAttentionBlock as _ResidualAttentionBlock
 from open_clip.transformer import Transformer as _Transformer
 from open_clip.transformer import VisionTransformer as _VisionTransformer
@@ -75,11 +75,18 @@ def __init__(
         self,
         image_size: int,
         patch_size: int,
+        global_average_pool: bool,
         output_dim: int,
         dtype: torch.dtype = torch.float32,
         **kwargs,
     ):
-        super().__init__(image_size, patch_size, output_dim=output_dim, **kwargs)
+        super().__init__(
+            image_size,
+            patch_size,
+            global_average_pool=global_average_pool,
+            output_dim=output_dim,
+            **kwargs,
+        )
         self.transformer = Transformer(dtype=dtype, **kwargs)
 
     def forward(self, x: torch.Tensor):
@@ -111,6 +118,7 @@ class CLIPVisionCfg:
     patch_size: int = 16
     image_size: Union[Tuple[int, int], int] = 224
     ls_init_value: Optional[float] = None  # layer scale initial value
+    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
     timm_model_name: str = (
         None  # a valid model name overrides layers, width, patch_size
     )
@@ -136,6 +144,7 @@ class CLIPTextCfg:
     ls_init_value: Optional[float] = None  # layer scale initial value
     hf_model_name: str = None
     hf_tokenizer_name: str = None
+    hf_model_pretrained: bool = True
     proj: str = 'mlp'
     pooler_type: str = 'mean_pooler'
 
@@ -189,6 +198,7 @@
             heads=vision_heads,
             mlp_ratio=vision_cfg.mlp_ratio,
             ls_init_value=vision_cfg.ls_init_value,
+            global_average_pool=vision_cfg.global_average_pool,
             output_dim=embed_dim,
             act_layer=act_layer,
             norm_layer=norm_layer,
@@ -208,11 +218,12 @@ def _build_text_tower(
         text_cfg = CLIPTextCfg(**text_cfg)
 
     if text_cfg.hf_model_name:
-        text = PreTrainedTextEncoder(
+        text = HFTextEncoder(
             text_cfg.hf_model_name,
             output_dim=embed_dim,
             proj=text_cfg.proj,
             pooler_type=text_cfg.pooler_type,
+            pretrained=text_cfg.hf_model_pretrained,
         )
     else:
         act_layer = QuickGELU if quick_gelu else nn.GELU
diff --git a/server/setup.py b/server/setup.py
index 5d9e9511e..b75b014ac 100644
--- a/server/setup.py
+++ b/server/setup.py
@@ -47,7 +47,7 @@
         'torchvision<=0.13.0' if sys.version_info <= (3, 7, 2) else 'torchvision',
         'jina>=3.12.0',
         'prometheus-client',
-        'open_clip_torch>=2.7.0',
+        'open_clip_torch>=2.8.0',
     ],
     extras_require={
         'onnx': [
diff --git a/tests/test_server.py b/tests/test_server.py
index d020ee12f..b5046baf0 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -12,7 +12,7 @@ def test_server_download(tmpdir):
     download_model(
         url='https://docarray.jina.ai/_static/favicon.png',
         target_folder=tmpdir,
-        md5sum='a084999188f4290e2654aec43207ff2e',
+        md5sum='66ea4817d73514888dcf6c7d2b00016d',
         with_resume=False,
     )
     target_path = os.path.join(tmpdir, 'favicon.png')
@@ -29,14 +29,14 @@
     download_model(
         url='https://docarray.jina.ai/_static/favicon.png',
         target_folder=tmpdir,
-        md5sum='a084999188f4290e2654aec43207ff2e',
+        md5sum='66ea4817d73514888dcf6c7d2b00016d',
         with_resume=True,
     )
     assert os.path.getsize(target_path) == file_size
     assert not os.path.exists(part_path)
 
 
-@pytest.mark.parametrize('md5', ['ABC', None, 'a084999188f4290e2654aec43207ff2e'])
+@pytest.mark.parametrize('md5', ['ABC', None, '66ea4817d73514888dcf6c7d2b00016d'])
 def test_server_download_md5(tmpdir, md5):
     if md5 != 'ABC':
         download_model(
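For context, a minimal sketch (not part of the diff) of how the two new config fields surface. It only instantiates the dataclasses patched above, leaving every other field at its default; the import path assumes the `server/clip_server/model/model.py` layout from the diff header, and `'roberta-base'` is an arbitrary placeholder HF model name, not one used by this change:

```python
from clip_server.model.model import CLIPVisionCfg, CLIPTextCfg

# ViT tower that mean-pools all patch tokens instead of reading the CLS token
# (the behavior toggled by the new `global_average_pool` field).
vision_cfg = CLIPVisionCfg(global_average_pool=True)

# HF text tower built with randomly initialized weights: `_build_text_tower`
# now forwards this flag as `HFTextEncoder(..., pretrained=...)`, so no
# checkpoint is downloaded when it is False.
text_cfg = CLIPTextCfg(hf_model_name='roberta-base', hf_model_pretrained=False)

assert vision_cfg.global_average_pool is True
assert text_cfg.hf_model_pretrained is False
```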