Skip to content

Commit

Permalink
feat(server): allow client sending tensor document (#678)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Apr 11, 2022
1 parent fa42dc5 commit 8b800ee
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
- name: Test
id: test
run: |
pytest --suppress-no-test-exit-code --cov=clip_client --cov-report=xml \
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
-v -s -m "not gpu" ${{ matrix.test-path }}
echo "::set-output name=codecov_flag::cas"
timeout-minutes: 30
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ jobs:
- name: Test
id: test
run: |
pytest --suppress-no-test-exit-code --cov=clip_client --cov-report=xml \
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
-v -s -m "not gpu" ${{ matrix.test-path }}
echo "::set-output name=codecov_flag::cas"
timeout-minutes: 30
Expand Down
2 changes: 2 additions & 0 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
c.load_uri_to_blob()
self._return_plain = False
yield c
elif c.tensor is not None:
yield c
else:
raise TypeError(f'unsupported input type {c!r} {c.content_type}')
else:
Expand Down
14 changes: 9 additions & 5 deletions server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(
**kwargs
):
super().__init__(**kwargs)
self._preprocess = clip._transform(_SIZE[name])
self._preprocess_blob = clip._transform_blob(_SIZE[name])
self._preprocess_tensor = clip._transform_ndarray(_SIZE[name])
self._model = CLIPOnnxModel(name)
if pool_backend == 'thread':
self._pool = ThreadPool(processes=num_worker_preprocess)
Expand All @@ -46,10 +47,13 @@ def __init__(

def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray':
for d in da:
if not d.blob and d.uri:
# in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri
d.load_uri_to_blob()
d.tensor = self._preprocess(Image.open(io.BytesIO(d.blob)))
if d.tensor is not None:
d.tensor = self._preprocess_tensor(d.tensor)
else:
if not d.blob and d.uri:
# in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri
d.load_uri_to_blob()
d.tensor = self._preprocess_blob(d.blob)
da.tensors = da.tensors.cpu().numpy()
return da

Expand Down
15 changes: 10 additions & 5 deletions server/clip_server/executors/clip_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,23 @@ def __init__(
else:
self._device = device
self._minibatch_size = minibatch_size
self._model, self._preprocess = clip.load(name, device=self._device, jit=jit)
self._model, self._preprocess_blob, self._preprocess_tensor = clip.load(
name, device=self._device, jit=jit
)
if pool_backend == 'thread':
self._pool = ThreadPool(processes=num_worker_preprocess)
else:
self._pool = Pool(processes=num_worker_preprocess)

def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray':
for d in da:
if not d.blob and d.uri:
# in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri
d.load_uri_to_blob()
d.tensor = self._preprocess(Image.open(io.BytesIO(d.blob)))
if d.tensor is not None:
d.tensor = self._preprocess_tensor(d.tensor)
else:
if not d.blob and d.uri:
# in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri
d.load_uri_to_blob()
d.tensor = self._preprocess_blob(d.blob)
da.tensors = da.tensors.to(self._device)
return da

Expand Down
34 changes: 31 additions & 3 deletions server/clip_server/model/clip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Originally from https://github.com/openai/CLIP. MIT License, Copyright (c) 2021 OpenAI

import os
import io
import urllib
import warnings
from typing import Union, List
Expand Down Expand Up @@ -91,9 +92,14 @@ def _convert_image_to_rgb(image):
return image.convert('RGB')


def _transform(n_px):
def _blob2image(blob):
return Image.open(io.BytesIO(blob))


def _transform_blob(n_px):
return Compose(
[
_blob2image,
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
Expand All @@ -106,6 +112,20 @@ def _transform(n_px):
)


def _transform_ndarray(n_px):
return Compose(
[
ToTensor(),
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)


def available_models() -> List[str]:
'''Returns the names of available CLIP models'''
return list(_MODELS.keys())
Expand Down Expand Up @@ -170,7 +190,11 @@ def load(
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == 'cpu':
model.float()
return model, _transform(model.visual.input_resolution)
return (
model,
_transform_blob(model.visual.input_resolution),
_transform_ndarray(model.visual.input_resolution),
)

# patch the device names
device_holder = torch.jit.trace(
Expand Down Expand Up @@ -235,7 +259,11 @@ def patch_float(module):

model.float()

return model, _transform(model.input_resolution.item())
return (
model,
_transform_blob(model.input_resolution.item()),
_transform_ndarray(model.input_resolution.item()),
)


def tokenize(
Expand Down
24 changes: 24 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os

import pytest
from clip_server.model.clip import _transform_ndarray, _transform_blob
from docarray import Document


@pytest.mark.parametrize(
'image_uri',
[
f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg',
'https://docarray.jina.ai/_static/favicon.png',
],
)
@pytest.mark.parametrize('size', [224, 288, 384, 448])
def test_server_preprocess_ndarray_image(image_uri, size):
d1 = Document(uri=image_uri)
d1.load_uri_to_blob()
d2 = Document(uri=image_uri)
d2.load_uri_to_image_tensor()

t1 = _transform_blob(size)(d1.blob).numpy()
t2 = _transform_ndarray(size)(d2.tensor).numpy()
assert t1.shape == t2.shape
3 changes: 3 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def test_plain_inputs(make_flow, inputs, port_generator):
uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
),
Document(text='hello, world'),
Document(
uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
).load_uri_to_image_tensor(),
]
),
DocumentArray.from_files(
Expand Down

0 comments on commit 8b800ee

Please sign in to comment.