Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: do not send blob from server when it is loaded in client #804

Merged
merged 9 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ def encode(self, content, **kwargs):
**self._get_post_payload(content, kwargs),
on_done=partial(self._gather_result, results=results),
)

for c in content:
if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):
c.pop('blob')

return self._unboxed_result(results)

def _gather_result(self, response, results: 'DocumentArray'):
Expand Down Expand Up @@ -160,7 +165,8 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
_mime = mimetypes.guess_type(c)[0]
if _mime and _mime.startswith('image'):
yield Document(
tags={'__created_by_CAS__': True}, uri=c
tags={'__created_by_CAS__': True, '__loaded_by_CAS__': True},
uri=c,
).load_uri_to_blob()
else:
yield Document(tags={'__created_by_CAS__': True}, text=c)
Expand All @@ -169,6 +175,7 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
yield c
elif not c.blob and c.uri:
c.load_uri_to_blob()
c.tags['__loaded_by_CAS__'] = True
yield c
elif c.tensor is not None:
yield c
Expand Down Expand Up @@ -301,6 +308,10 @@ async def aencode(self, content, **kwargs):
),
)

for c in content:
if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):
c.pop('blob')

return self._unboxed_result(results)

def _prepare_streaming(self, disable, total):
Expand Down Expand Up @@ -331,6 +342,7 @@ def _prepare_single_doc(d: 'Document'):
return d
elif not d.blob and d.uri:
d.load_uri_to_blob()
d.tags['__loaded_by_CAS__'] = True
return d
elif d.tensor is not None:
return d
Expand All @@ -346,6 +358,18 @@ def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)])
return d

@staticmethod
def _reset_rank_doc(d: 'Document', _source: str = 'matches'):
    """Strip client-side loaded blobs from ``d`` and its candidate documents.

    Documents whose ``__loaded_by_CAS__`` tag is set had their blob loaded
    by the client purely for transport to the server; after ranking, the
    tag is removed and the blob popped so the caller gets back documents
    with their original content.

    :param d: the ranked document, cleaned up in place.
    :param _source: name of the attribute holding the candidate documents
        (``'matches'`` by default).
    :return: the same document ``d``.
    """

    def _strip(doc):
        # tags.pop returns False when the tag is absent, so documents the
        # client never touched keep their blob untouched.
        if doc.tags.pop('__loaded_by_CAS__', False):
            doc.pop('blob')

    _strip(d)
    for candidate in getattr(d, _source):
        _strip(candidate)
    return d

def _iter_rank_docs(
self, content, _source='matches'
) -> Generator['Document', None, None]:
Expand Down Expand Up @@ -408,6 +432,9 @@ def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
**self._get_rank_payload(docs, kwargs),
on_done=partial(self._gather_result, results=results),
)
for d in docs:
self._reset_rank_doc(d, _source=kwargs.get('source', 'matches'))

return results

async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
Expand Down Expand Up @@ -435,4 +462,7 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
),
)

for d in docs:
self._reset_rank_doc(d, _source=kwargs.get('source', 'matches'))

return results
5 changes: 4 additions & 1 deletion server/clip_server/executors/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def preproc_image(
tensors_batch.append(preprocess_fn(d.tensor).detach())

# recover doc content
d.content = content
if d.tags.pop('__loaded_by_CAS__', False):
d.pop('tensor')
else:
d.content = content

tensors_batch = torch.stack(tensors_batch).type(torch.float32)

Expand Down
37 changes: 36 additions & 1 deletion tests/test_asyncio.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio

import os
import pytest

from clip_client import Client
from docarray import Document, DocumentArray


async def another_heavylifting_job():
Expand All @@ -16,3 +17,37 @@ async def test_async_encode(make_flow):
t2 = asyncio.create_task(c.aencode(['hello world'] * 10))
await asyncio.gather(t1, t2)
assert t2.result().shape


@pytest.mark.parametrize(
    'inputs',
    [
        DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]),
        DocumentArray(
            [
                Document(
                    uri='https://docarray.jina.ai/_static/favicon.png',
                    text='hello, world',
                ),
            ]
        ),
        DocumentArray.from_files(
            f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
        ),
    ],
)
@pytest.mark.asyncio
async def test_async_docarray_preserve_original_inputs(
    make_flow, inputs, port_generator
):
    """Async encoding must return embeddings without mutating the inputs.

    Runs ``aencode`` concurrently with an unrelated coroutine, then checks
    that the results carry embeddings, keep the original contents, and that
    no CAS bookkeeping tags or transport blobs/tensors leak back out.
    """
    client = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
    payload = inputs() if callable(inputs) else inputs
    background = asyncio.create_task(another_heavylifting_job())
    encode_task = asyncio.create_task(client.aencode(payload))
    await asyncio.gather(background, encode_task)

    results = encode_task.result()
    assert isinstance(results, DocumentArray)
    assert results.embeddings.shape
    # the round trip must hand back the original document content untouched
    assert results.contents == inputs.contents
    first = results[0]
    assert '__created_by_CAS__' not in first.tags
    assert '__loaded_by_CAS__' not in first.tags
    assert not first.tensor
    assert not first.blob
30 changes: 30 additions & 0 deletions tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from clip_server.executors.helper import numpy_softmax
from clip_server.executors.helper import split_img_txt_da
from clip_server.executors.helper import preproc_image
from docarray import Document, DocumentArray


Expand Down Expand Up @@ -77,3 +78,32 @@ def test_split_img_txt_da(inputs):
split_img_txt_da(doc, img_da, txt_da)
assert len(txt_da) == inputs[1][0]
assert len(img_da) == inputs[1][1]


@pytest.mark.parametrize(
    'inputs',
    [
        # NOTE(review): building these fixtures fetches the favicon over the
        # network at collection time — confirm this is acceptable in CI.
        DocumentArray(
            [
                Document(
                    uri='https://docarray.jina.ai/_static/favicon.png',
                    tags={'__loaded_by_CAS__': True},
                ).load_uri_to_blob(),
                Document(
                    uri='https://docarray.jina.ai/_static/favicon.png',
                ).load_uri_to_blob(),
            ]
        )
    ],
)
def test_preproc_image(inputs):
    """``preproc_image`` must drop transport data only for CAS-loaded docs.

    The document tagged ``__loaded_by_CAS__`` has its blob removed after
    preprocessing, while the untagged document keeps its blob; neither
    document retains a tensor, and the batched pixel values are produced.
    """
    from clip_server.model import clip

    preprocess_fn = clip._transform_ndarray(224)
    processed, pixel_values = preproc_image(inputs, preprocess_fn)

    assert len(processed) == 2
    cas_loaded, user_loaded = processed
    # the blob was injected by the client, so it must not survive
    assert not cas_loaded.blob
    # the user supplied this blob themselves, so it is preserved
    assert user_loaded.blob
    assert not cas_loaded.tensor
    assert not user_loaded.tensor
    assert pixel_values.get('pixel_values') is not None
22 changes: 22 additions & 0 deletions tests/test_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ async def test_torch_executor_rank_img2texts(encoder_class):
for d in da:
for c in d.matches:
assert c.scores['clip_score'].value is not None
assert '__loaded_by_CAS__' not in c.tags
assert not c.tensor
assert not c.blob
org_score = d.matches[:, 'scores__clip_score__value']
assert org_score == list(sorted(org_score, reverse=True))
assert '__loaded_by_CAS__' not in d.tags
assert not d.tensor
assert not d.blob


@pytest.mark.asyncio
Expand All @@ -53,9 +59,15 @@ async def test_torch_executor_rank_text2imgs(encoder_class):
for c in d.matches:
assert c.scores['clip_score'].value is not None
assert c.scores['clip_score_cosine'].value is not None
assert '__loaded_by_CAS__' not in c.tags
assert not c.tensor
assert not c.blob
np.testing.assert_almost_equal(
sum(c.scores['clip_score'].value for c in d.matches), 1
)
assert '__loaded_by_CAS__' not in d.tags
assert not d.tensor
assert not d.blob


@pytest.mark.parametrize(
Expand All @@ -79,6 +91,14 @@ async def test_torch_executor_rank_text2imgs(encoder_class):
def test_docarray_inputs(make_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
r = c.rank([d])
assert r[0].content == d.content
assert r[0].matches.contents == d.matches.contents
assert '__loaded_by_CAS__' not in d.tags
assert not d.blob
assert not d.tensor
assert '__loaded_by_CAS__' not in d.matches[0].tags
assert not d.matches[0].blob
assert not d.matches[0].tensor
assert isinstance(r, DocumentArray)
rv1 = r['@m', 'scores__clip_score__value']
rv2 = r['@m', 'scores__clip_score_cosine__value']
Expand Down Expand Up @@ -112,6 +132,8 @@ async def test_async_arank(make_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
r = await c.arank([d])
assert isinstance(r, DocumentArray)
assert r[0].content == d.content
assert r[0].matches.contents == d.matches.contents
rv = r['@m', 'scores__clip_score__value']
for v in rv:
assert v is not None
Expand Down
16 changes: 16 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def test_docarray_inputs(make_flow, inputs, port_generator):
assert isinstance(r, DocumentArray)
assert r.embeddings.shape
assert '__created_by_CAS__' not in r[0].tags
assert '__loaded_by_CAS__' not in r[0].tags
assert not r[0].tensor
assert not r[0].blob


@pytest.mark.parametrize(
Expand All @@ -104,6 +107,9 @@ def test_docarray_preserve_original_inputs(make_flow, inputs, port_generator):
assert r.embeddings.shape
assert r.contents == inputs.contents
assert '__created_by_CAS__' not in r[0].tags
assert '__loaded_by_CAS__' not in r[0].tags
assert not r[0].tensor
assert not r[0].blob


@pytest.mark.parametrize(
Expand Down Expand Up @@ -134,5 +140,15 @@ def test_docarray_traversal(make_flow, inputs, port_generator):
r2 = c.post(on='/', inputs=da, parameters={'access_paths': '@c'})
assert r1[0].chunks.embeddings.shape[0] == len(inputs)
assert '__created_by_CAS__' not in r1[0].tags
assert '__loaded_by_CAS__' not in r1[0].tags
assert not r1[0].tensor
assert not r1[0].blob
assert not r1[0].chunks[0].tensor
assert not r1[0].chunks[0].blob
assert r2[0].chunks.embeddings.shape[0] == len(inputs)
assert '__created_by_CAS__' not in r2[0].tags
assert '__loaded_by_CAS__' not in r2[0].tags
assert not r2[0].tensor
assert not r2[0].blob
assert not r2[0].chunks[0].tensor
assert not r2[0].chunks[0].blob