Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: do not send blob from server when it is loaded in client #804

Merged
merged 9 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def encode(
...

def encode(self, content, **kwargs):
from docarray import Document

if isinstance(content, str):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.encode(["{content}"])` instead'
Expand All @@ -119,6 +121,11 @@ def encode(self, content, **kwargs):
**self._get_post_payload(content, kwargs),
on_done=partial(self._gather_result, results=results),
)

for c in content:
if isinstance(c, Document) and c.tags.pop('__loaded_by_CAS__', False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(c, Document) and c.tags.pop('__loaded_by_CAS__', False):
if hasattr(c, 'tags') and c.tags.pop('__loaded_by_CAS__', False):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change.

Copy link
Member

@numb3r3 numb3r3 Aug 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The benefits of the change above:

  1. it reduces the memory footprint (addressing OOM when processing a large number of documents);
  2. it leaves the input documents unmodified.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better approach in the future would be to support in-place updates in `encode(docs)`.

c.pop('blob')

return self._unboxed_result(results)

def _gather_result(self, response, results: 'DocumentArray'):
Expand Down Expand Up @@ -160,7 +167,8 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
_mime = mimetypes.guess_type(c)[0]
if _mime and _mime.startswith('image'):
yield Document(
tags={'__created_by_CAS__': True}, uri=c
tags={'__created_by_CAS__': True, '__loaded_by_CAS__': True},
uri=c,
).load_uri_to_blob()
else:
yield Document(tags={'__created_by_CAS__': True}, text=c)
Expand All @@ -169,6 +177,7 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
yield c
elif not c.blob and c.uri:
c.load_uri_to_blob()
c.tags['__loaded_by_CAS__'] = True
yield c
elif c.tensor is not None:
yield c
Expand Down Expand Up @@ -331,6 +340,7 @@ def _prepare_single_doc(d: 'Document'):
return d
elif not d.blob and d.uri:
d.load_uri_to_blob()
d.tags['__loaded_by_CAS__'] = True
return d
elif d.tensor is not None:
return d
Expand All @@ -346,6 +356,18 @@ def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)])
return d

@staticmethod
def _reset_rank_doc(d: 'Document', _source: str = 'matches'):
    """Strip the ``__loaded_by_CAS__`` marker and the associated blob from a rank doc.

    The marker is removed from the tags of ``d`` and of every document found
    under the attribute named by ``_source``. Whenever the removed marker was
    truthy, that document's ``blob`` field is popped as well.

    :param d: the document to clean up; it is modified in place.
    :param _source: name of the attribute on ``d`` holding the nested
        documents (``'matches'`` by default).
    :return: the same ``d`` object that was passed in.
    """
    marker = '__loaded_by_CAS__'

    def _strip(doc):
        # The tag is always removed; the blob only goes when the tag was truthy.
        was_loaded = doc.tags.pop(marker, False)
        if was_loaded:
            doc.pop('blob')

    # Clean the top-level document first, then each nested one.
    _strip(d)
    for nested in getattr(d, _source):
        _strip(nested)
    return d

def _iter_rank_docs(
self, content, _source='matches'
) -> Generator['Document', None, None]:
Expand Down Expand Up @@ -408,6 +430,9 @@ def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
**self._get_rank_payload(docs, kwargs),
on_done=partial(self._gather_result, results=results),
)
for d in docs:
self._reset_rank_doc(d, _source=kwargs.get('source', 'matches'))

return results

async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
Expand Down
5 changes: 4 additions & 1 deletion server/clip_server/executors/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def preproc_image(
tensors_batch.append(preprocess_fn(d.tensor).detach())

# recover doc content
d.content = content
if d.tags.pop('__loaded_by_CAS__', False):
d.pop('tensor')
else:
d.content = content

tensors_batch = torch.stack(tensors_batch).type(torch.float32)

Expand Down
18 changes: 18 additions & 0 deletions tests/test_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ async def test_torch_executor_rank_img2texts(encoder_class):
for d in da:
for c in d.matches:
assert c.scores['clip_score'].value is not None
assert '__loaded_by_CAS__' not in c.tags
assert not c.tensor
assert not c.blob
org_score = d.matches[:, 'scores__clip_score__value']
assert org_score == list(sorted(org_score, reverse=True))
assert '__loaded_by_CAS__' not in d.tags
assert not d.tensor
assert not d.blob


@pytest.mark.asyncio
Expand All @@ -53,9 +59,15 @@ async def test_torch_executor_rank_text2imgs(encoder_class):
for c in d.matches:
assert c.scores['clip_score'].value is not None
assert c.scores['clip_score_cosine'].value is not None
assert '__loaded_by_CAS__' not in c.tags
assert not c.tensor
assert not c.blob
np.testing.assert_almost_equal(
sum(c.scores['clip_score'].value for c in d.matches), 1
)
assert '__loaded_by_CAS__' not in d.tags
assert not d.tensor
assert not d.blob


@pytest.mark.parametrize(
Expand All @@ -79,6 +91,12 @@ async def test_torch_executor_rank_text2imgs(encoder_class):
def test_docarray_inputs(make_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
r = c.rank([d])
assert '__loaded_by_CAS__' not in d.tags
assert not d.blob
assert not d.tensor
assert '__loaded_by_CAS__' not in d.matches[0].tags
assert not d.matches[0].blob
assert not d.matches[0].tensor
assert isinstance(r, DocumentArray)
rv1 = r['@m', 'scores__clip_score__value']
rv2 = r['@m', 'scores__clip_score_cosine__value']
Expand Down
16 changes: 16 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def test_docarray_inputs(make_flow, inputs, port_generator):
assert isinstance(r, DocumentArray)
assert r.embeddings.shape
assert '__created_by_CAS__' not in r[0].tags
assert '__loaded_by_CAS__' not in r[0].tags
assert not r[0].tensor
assert not r[0].blob


@pytest.mark.parametrize(
Expand All @@ -104,6 +107,9 @@ def test_docarray_preserve_original_inputs(make_flow, inputs, port_generator):
assert r.embeddings.shape
assert r.contents == inputs.contents
assert '__created_by_CAS__' not in r[0].tags
assert '__loaded_by_CAS__' not in r[0].tags
assert not r[0].tensor
assert not r[0].blob


@pytest.mark.parametrize(
Expand Down Expand Up @@ -134,5 +140,15 @@ def test_docarray_traversal(make_flow, inputs, port_generator):
r2 = c.post(on='/', inputs=da, parameters={'access_paths': '@c'})
assert r1[0].chunks.embeddings.shape[0] == len(inputs)
assert '__created_by_CAS__' not in r1[0].tags
assert '__loaded_by_CAS__' not in r1[0].tags
assert not r1[0].tensor
assert not r1[0].blob
assert not r1[0].chunks[0].tensor
assert not r1[0].chunks[0].blob
assert r2[0].chunks.embeddings.shape[0] == len(inputs)
assert '__created_by_CAS__' not in r2[0].tags
assert '__loaded_by_CAS__' not in r2[0].tags
assert not r2[0].tensor
assert not r2[0].blob
assert not r2[0].chunks[0].tensor
assert not r2[0].chunks[0].blob