Skip to content

Commit

Permalink
fix: reset input
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiniuYu committed Aug 11, 2022
1 parent 3ecf61f commit 4d61e28
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 2 deletions.
22 changes: 22 additions & 0 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def encode(
...

def encode(self, content, **kwargs):
from docarray import Document

if isinstance(content, str):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.encode(["{content}"])` instead'
Expand All @@ -119,6 +121,11 @@ def encode(self, content, **kwargs):
**self._get_post_payload(content, kwargs),
on_done=partial(self._gather_result, results=results),
)

for c in content:
if isinstance(c, Document) and c.tags.pop('__loaded_by_CAS__', False):
c.pop('blob')

return self._unboxed_result(results)

def _gather_result(self, response, results: 'DocumentArray'):
Expand Down Expand Up @@ -349,6 +356,18 @@ def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)])
return d

@staticmethod
def _reset_rank_doc(d: 'Document', _source: str = 'matches'):
_get = lambda d: getattr(d, _source)

if d.tags.pop('__loaded_by_CAS__', False):
d.pop('blob')

for c in _get(d):
if c.tags.pop('__loaded_by_CAS__', False):
c.pop('blob')
return d

def _iter_rank_docs(
self, content, _source='matches'
) -> Generator['Document', None, None]:
Expand Down Expand Up @@ -411,6 +430,9 @@ def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
**self._get_rank_payload(docs, kwargs),
on_done=partial(self._gather_result, results=results),
)
for d in docs:
self._reset_rank_doc(d, _source=kwargs.get('source', 'matches'))

return results

async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
Expand Down
2 changes: 1 addition & 1 deletion server/clip_server/executors/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def preproc_image(
tensors_batch.append(preprocess_fn(d.tensor).detach())

# recover doc content
if d.tags.pop('__loaded_by_CAS__', None):
if d.tags.pop('__loaded_by_CAS__', False):
d.pop('tensor')
else:
d.content = content
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def random_port():
return random_port


@pytest.fixture(scope='session', params=['onnx', 'torch', 'onnx_custom'])
@pytest.fixture(scope='session', params=['onnx'])
def make_flow(port_generator, request):
if request.param != 'onnx_custom':
if request.param == 'onnx':
Expand Down
6 changes: 6 additions & 0 deletions tests/test_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ async def test_torch_executor_rank_text2imgs(encoder_class):
def test_docarray_inputs(make_flow, d):
c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
r = c.rank([d])
assert '__loaded_by_CAS__' not in d.tags
assert not d.blob
assert not d.tensor
assert '__loaded_by_CAS__' not in d.matches[0].tags
assert not d.matches[0].blob
assert not d.matches[0].tensor
assert isinstance(r, DocumentArray)
rv1 = r['@m', 'scores__clip_score__value']
rv2 = r['@m', 'scores__clip_score_cosine__value']
Expand Down
2 changes: 2 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def test_docarray_preserve_original_inputs(make_flow, inputs, port_generator):
assert r.contents == inputs.contents
assert '__created_by_CAS__' not in r[0].tags
assert '__loaded_by_CAS__' not in r[0].tags
assert not r[0].tensor
assert not r[0].blob


@pytest.mark.parametrize(
Expand Down

0 comments on commit 4d61e28

Please sign in to comment.