feat: add clip_hg executor (#740)
* feat: add clip_hg

* feat: add clip_hg

* feat: add clip_hg

* fix: remove unused parameters

* fix: remove unused parameters

* chore: rearrange requirement

* fix: remove deepcopy of input doc

* fix: add hg-flow; rename hg

* fix: add huggingface installation in ci/cd

* fix: remove unused import

* fix: add unit tests for clip_hg
ZiniuYu authored Jun 9, 2022
1 parent 130108c commit d675148
Showing 8 changed files with 322 additions and 10 deletions.
1 change: 1 addition & 0 deletions .github/workflows/cd.yml
@@ -42,6 +42,7 @@ jobs:
          python -m pip install wheel
          pip install --no-cache-dir "client/[test]"
          pip install --no-cache-dir "server/[onnx]"
+          pip install --no-cache-dir "server/[huggingface]"
      - name: Test
        id: test
        run: |
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -114,6 +114,7 @@ jobs:
          python -m pip install wheel pytest pytest-cov
          pip install --no-cache-dir "client/[test]"
          pip install --no-cache-dir "server/[onnx]"
+          pip install --no-cache-dir "server/[huggingface]"
      - name: Test
        id: test
        run: |
247 changes: 247 additions & 0 deletions server/clip_server/executors/clip_hg.py
@@ -0,0 +1,247 @@
import os
import warnings
from multiprocessing.pool import ThreadPool
from typing import Dict, Optional, Sequence
import numpy as np
import torch
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer
from clip_server.executors.helper import (
    split_img_txt_da,
    set_rank,
)
from clip_server.model import clip
from jina import Executor, requests, DocumentArray, monitor


class CLIPEncoder(Executor):
    def __init__(
        self,
        pretrained_model_name_or_path: str = 'openai/clip-vit-base-patch32',
        finetuned_checkpoint_path: Optional[str] = None,
        base_feature_extractor: Optional[str] = None,
        base_tokenizer_model: Optional[str] = None,
        use_default_preprocessing: bool = True,
        max_length: int = 77,
        device: str = 'cpu',
        overwrite_embeddings: bool = False,
        num_worker_preprocess: int = 4,
        minibatch_size: int = 32,
        *args,
        **kwargs,
    ):
        """
        :param pretrained_model_name_or_path: Can be either:
            - A string, the model id of a pretrained CLIP model hosted
              inside a model repo on huggingface.co, e.g.,
              'openai/clip-vit-base-patch32'
            - A path to a directory containing model weights saved, e.g.,
              ./my_model_directory/
        :param finetuned_checkpoint_path: If set, the pretrained model weights will be replaced with weights
            loaded from the given checkpoint.
        :param base_feature_extractor: Base feature extractor for images.
            Defaults to ``pretrained_model_name_or_path`` if None.
        :param base_tokenizer_model: Base tokenizer model.
            Defaults to ``pretrained_model_name_or_path`` if None.
        :param use_default_preprocessing: Whether to use the `base_feature_extractor`
            on images (tensors) before encoding them. If you disable this, you must
            ensure that the images you pass in have the correct format, see the
            ``encode`` method for details.
        :param max_length: Max length argument for the tokenizer. All CLIP models
            use 77 as the max length.
        :param device: Pytorch device to put the model on, e.g. 'cpu', 'cuda',
            'cuda:1'.
        :param overwrite_embeddings: Whether to overwrite existing embeddings. By
            default docs that have embeddings already are not processed. This value
            can be overwritten if the same parameter is passed to the request.
        :param num_worker_preprocess: Number of CPU workers used in the preprocessing step.
        :param minibatch_size: Default batch size for encoding, used if the
            batch size is not passed as a parameter with the request.
        """
        super().__init__(*args, **kwargs)
        self.overwrite_embeddings = overwrite_embeddings
        self._minibatch_size = minibatch_size
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.base_tokenizer_model = (
            base_tokenizer_model or pretrained_model_name_or_path
        )
        self.use_default_preprocessing = use_default_preprocessing
        self.base_feature_extractor = (
            base_feature_extractor or pretrained_model_name_or_path
        )
        self.max_length = max_length

        # self.device = device
        if not device:
            self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self._device = device

        if not self._device.startswith('cuda') and (
            'OMP_NUM_THREADS' not in os.environ
            and hasattr(self.runtime_args, 'replicas')
        ):
            replicas = getattr(self.runtime_args, 'replicas', 1)
            num_threads = max(1, torch.get_num_threads() // replicas)
            if num_threads < 2:
                warnings.warn(
                    f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in '
                    f'sub-optimal performance.'
                )

            # NOTE: make sure to set the threads right after the torch import,
            # and `torch.set_num_threads` always takes precedence over the environment variable `OMP_NUM_THREADS`.
            # For more details, please see https://pytorch.org/docs/stable/generated/torch.set_num_threads.html
            # FIXME: This hack would harm the performance in K8S deployment.
            torch.set_num_threads(max(num_threads, 1))
            torch.set_num_interop_threads(1)

        self.vision_preprocessor = CLIPFeatureExtractor.from_pretrained(
            self.base_feature_extractor
        )
        self.tokenizer = CLIPTokenizer.from_pretrained(self.base_tokenizer_model)
        self._model = CLIPModel.from_pretrained(self.pretrained_model_name_or_path)

        if finetuned_checkpoint_path:
            if finetuned_checkpoint_path.startswith(
                'https://'
            ) or finetuned_checkpoint_path.startswith('http://'):
                state_dict = torch.hub.load_state_dict_from_url(
                    finetuned_checkpoint_path, map_location='cpu', progress=True
                )
            else:
                state_dict = torch.load(finetuned_checkpoint_path, map_location='cpu')
            self._model.load_state_dict(state_dict)

        self._model.eval().to(self._device)
        self._pool = ThreadPool(processes=num_worker_preprocess)

    @monitor(name='preprocess_images_seconds')
    def _preproc_images(self, docs: 'DocumentArray'):
        contents = docs.contents
        tensors_batch = []
        for d in docs:
            if d.blob:
                d.convert_blob_to_image_tensor()
            elif d.uri:
                d.load_uri_to_image_tensor()
            tensors_batch.append(d.tensor)
        if self.use_default_preprocessing:
            docs.tensors = self._preprocess_images(tensors_batch)['pixel_values']
        else:
            docs.tensors = torch.tensor(
                tensors_batch, dtype=torch.float32, device=self._device
            )
        return docs, contents

    @monitor(name='encode_images_seconds')
    def _encode_images(self, docs: DocumentArray):
        docs.embeddings = (
            self._model.get_image_features(docs.tensors)
            .cpu()
            .numpy()
            .astype(np.float32)
        )

    @monitor(name='preprocess_texts_seconds')
    def _preproc_texts(self, docs: 'DocumentArray'):
        contents = docs.contents
        docs.tensors = self._tokenize_texts(docs.texts)['input_ids']
        return docs, contents

    @monitor(name='encode_texts_seconds')
    def _encode_texts(self, docs: 'DocumentArray'):
        docs.embeddings = (
            self._model.get_text_features(docs.tensors).cpu().numpy().astype(np.float32)
        )

    @requests(on='/rank')
    async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
        await self.encode(docs['@r,m'])

        set_rank(docs)

    @requests
    async def encode(self, docs: DocumentArray, **kwargs):
        """
        Encode all documents with `text` or image content using the corresponding CLIP
        encoder. Store the embeddings in the `embedding` attribute. Documents with
        existing embeddings are not processed unless `overwrite_embeddings` is set to
        True.
        :param docs: Documents sent to the encoder. The image docs must have
            ``tensor`` of the shape ``Height x Width x 3``. By default, the input
            ``tensor`` must be an ``ndarray`` with ``dtype=uint8`` or ``dtype=float32``.
            If you set ``use_default_preprocessing=True`` when creating this encoder,
            then the ``tensor`` arrays should have the shape ``[H, W, 3]``, and be in
            the RGB color format with ``dtype=uint8``.
            If you set ``use_default_preprocessing=False`` when creating this encoder,
            then you need to ensure that the images you pass in are already
            pre-processed: they must all be the same size for batching (the CLIP
            model was trained on ``224 x 224`` images), have the shape ``[3, H, W]``
            with ``dtype=float32``, and be normalized to values between 0 and 1.
        """
        _img_da = DocumentArray()
        _txt_da = DocumentArray()
        for d in docs:
            split_img_txt_da(d, _img_da, _txt_da)

        with torch.inference_mode():
            # for image
            if _img_da:
                for minibatch, _contents in _img_da.map_batch(
                    self._preproc_images,
                    batch_size=self._minibatch_size,
                    pool=self._pool,
                ):

                    self._encode_images(minibatch)

                    # recover original content
                    try:
                        _ = iter(_contents)
                        for _d, _ct in zip(minibatch, _contents):
                            _d.content = _ct
                    except TypeError:
                        pass

            # for text
            if _txt_da:
                for minibatch, _contents in _txt_da.map_batch(
                    self._preproc_texts,
                    batch_size=self._minibatch_size,
                    pool=self._pool,
                ):
                    self._encode_texts(minibatch)

                    # recover original content
                    try:
                        _ = iter(_contents)
                        for _d, _ct in zip(minibatch, _contents):
                            _d.content = _ct
                    except TypeError:
                        pass

        # drop tensors
        if self.use_default_preprocessing:
            docs.tensors = None
        return docs

    def _preprocess_images(self, images):
        """Preprocess images."""
        x = self.vision_preprocessor(
            images=images,
            return_tensors='pt',
        )
        return {k: v.to(torch.device(self._device)) for k, v in x.items()}

    def _tokenize_texts(self, texts: Sequence[str]):
        """Tokenize texts."""
        x = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding='longest',
            truncation=True,
            return_tensors='pt',
        )
        return {k: v.to(self._device) for k, v in x.items()}
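
For orientation, a minimal usage sketch of the new executor (not part of this commit): it serves CLIPEncoder in-process on a jina Flow and encodes two text Documents. The port, the flow name and the example texts are placeholders; image Documents can likewise be passed via `uri`, `blob` or a pre-loaded `tensor`, which `_preproc_images` above handles.

from jina import Document, DocumentArray, Flow

from clip_server.executors.clip_hg import CLIPEncoder

# Serve the Hugging Face-backed encoder on a local Flow.
f = Flow(port=51000).add(
    name='clip_h',
    uses=CLIPEncoder,
    uses_with={'pretrained_model_name_or_path': 'openai/clip-vit-base-patch32'},
)

with f:
    da = DocumentArray(
        [Document(text='a photo of a cat'), Document(text='a photo of a dog')]
    )
    result = f.post(on='/', inputs=da)
    # ViT-B/32 produces 512-dimensional embeddings, so this should print (2, 512).
    print(result.embeddings.shape)
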
11 changes: 11 additions & 0 deletions server/clip_server/hg-flow.yml
@@ -0,0 +1,11 @@
jtype: Flow
version: '1'
with:
  port: 51000
executors:
  - name: clip_h
    uses:
      jtype: CLIPEncoder
      metas:
        py_modules:
          - executors/clip_hg.py
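
A hedged sketch of how this flow could be used end to end (not part of this commit): install the server with the new extra (e.g. `pip install "clip-server[huggingface]"`), start it with `python -m clip_server hg-flow.yml`, then query it from `clip_client`. The texts are placeholders; image file paths or URLs can be mixed into the same list.

from clip_client import Client

# Connect to the flow defined above, which listens on port 51000.
c = Client('grpc://0.0.0.0:51000')

# Encode a small batch of texts; the result is a numpy array of embeddings.
r = c.encode(['a photo of a cat', 'a photo of a dog'])
print(r.shape)  # expected to be (2, 512) for openai/clip-vit-base-patch32
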
14 changes: 7 additions & 7 deletions server/clip_server/model/model.py
@@ -74,7 +74,7 @@ def __init__(
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
-            torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
+            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
@@ -267,7 +267,7 @@ def __init__(
            bias=False,
        )

-        scale = width ** -0.5
+        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
@@ -373,7 +373,7 @@ def initialize_parameters(self):

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
-                std = self.visual.attnpool.c_proj.in_features ** -0.5
+                std = self.visual.attnpool.c_proj.in_features**-0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
@@ -389,10 +389,10 @@ def initialize_parameters(self):
                    if name.endswith('bn3.weight'):
                        nn.init.zeros_(param)

-        proj_std = (self.transformer.width ** -0.5) * (
+        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
-        attn_std = self.transformer.width ** -0.5
+        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
@@ -401,7 +401,7 @@ def initialize_parameters(self):
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
-            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+            nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
@@ -514,7 +514,7 @@ def build_model(state_dict: dict):
        )
        vision_patch_size = None
        assert (
-            output_width ** 2 + 1
+            output_width**2 + 1
            == state_dict['visual.attnpool.positional_embedding'].shape[0]
        )
        image_resolution = output_width * 32
1 change: 1 addition & 0 deletions server/setup.py
@@ -59,6 +59,7 @@
        ]
        + (['onnxruntime-gpu>=1.8.0'] if sys.platform != 'darwin' else []),
        'tensorrt': ['nvidia-tensorrt'],
+        'huggingface': ['transformers>=4.16.2'],
    },
    classifiers=[
        'Development Status :: 5 - Production/Stable',
30 changes: 27 additions & 3 deletions tests/conftest.py
@@ -16,12 +16,14 @@ def random_port():
    return random_port


-@pytest.fixture(scope='session', params=['onnx', 'torch'])
+@pytest.fixture(scope='session', params=['onnx', 'torch', 'hg'])
def make_flow(port_generator, request):
    if request.param == 'onnx':
        from clip_server.executors.clip_onnx import CLIPEncoder
-    else:
+    elif request.param == 'torch':
        from clip_server.executors.clip_torch import CLIPEncoder
+    else:
+        from clip_server.executors.clip_hg import CLIPEncoder

    f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
    with f:
@@ -37,10 +39,32 @@ def make_torch_flow(port_generator, request):
        yield f


-@pytest.fixture(scope='session', params=['torch'])
+@pytest.fixture(scope='session', params=['tensorrt'])
def make_trt_flow(port_generator, request):
    from clip_server.executors.clip_tensorrt import CLIPEncoder

    f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
    with f:
        yield f


+@pytest.fixture(scope='session', params=['hg'])
+def make_hg_flow(port_generator, request):
+    from clip_server.executors.clip_hg import CLIPEncoder
+
+    f = Flow(port=port_generator()).add(name=request.param, uses=CLIPEncoder)
+    with f:
+        yield f
+
+
+@pytest.fixture(scope='session', params=['hg'])
+def make_hg_flow_no_default(port_generator, request):
+    from clip_server.executors.clip_hg import CLIPEncoder
+
+    f = Flow(port=port_generator()).add(
+        name=request.param,
+        uses=CLIPEncoder,
+        uses_with={'use_default_preprocessing': False},
+    )
+    with f:
+        yield f
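
As a rough illustration of how the new fixtures might be exercised (a hypothetical test, not part of this commit; the test name and the assertion on the ViT-B/32 embedding size are assumptions):

import numpy as np
import pytest
from jina import Document, DocumentArray


@pytest.mark.parametrize('text', ['hello clip', 'a photo of a dog'])
def test_hg_flow_encodes_text(make_hg_flow, text):
    da = DocumentArray([Document(text=text)])
    res = make_hg_flow.post(on='/', inputs=da)
    # The executor casts embeddings to float32; ViT-B/32 embeds into 512 dims.
    assert res.embeddings.dtype == np.float32
    assert res.embeddings.shape == (1, 512)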