From 65032f02db30671f7a2a6ca78e371588ae98ab2b Mon Sep 17 00:00:00 2001 From: Ziniu Yu Date: Fri, 5 Aug 2022 14:36:52 +0800 Subject: [PATCH] feat: encode text first when both text and uri are presented (#795) * fix: encode text first when both text and uri are presented * fix: encode text first when both text and uri are presented * test: add split da test * fix: typo * test: test split da --- server/clip_server/executors/helper.py | 8 ++-- tests/test_helper.py | 60 ++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/server/clip_server/executors/helper.py b/server/clip_server/executors/helper.py index da24597b9..3c97a34de 100644 --- a/server/clip_server/executors/helper.py +++ b/server/clip_server/executors/helper.py @@ -73,12 +73,10 @@ def preproc_text( def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'): - if doc.uri: - img_da.append(doc) - elif doc.blob or (doc.tensor is not None): - img_da.append(doc) - elif doc.text: + if doc.text: txt_da.append(doc) + elif doc.blob or (doc.tensor is not None) or doc.uri: + img_da.append(doc) def set_rank(docs, _logit_scale=np.exp(4.60517)): diff --git a/tests/test_helper.py b/tests/test_helper.py index 5a61441c2..f7d79ac62 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -1,6 +1,8 @@ import pytest import numpy as np from clip_server.executors.helper import numpy_softmax +from clip_server.executors.helper import split_img_txt_da +from docarray import Document, DocumentArray @pytest.mark.parametrize('shape', [(5, 10), (5, 10, 10)]) @@ -17,3 +19,61 @@ def test_numpy_softmax(shape, axis): np_softmax = numpy_softmax(logits, axis=axis) torch_softmax = torch.from_numpy(logits).softmax(dim=axis).numpy() np.testing.assert_array_almost_equal(np_softmax, torch_softmax) + + +@pytest.mark.parametrize( + 'inputs', + [ + ( + DocumentArray( + [ + Document(text='hello, world'), + Document(text='goodbye, world'), + Document( + text='hello, world', + uri='https://docarray.jina.ai/_static/favicon.png', + ), + Document( + uri='https://docarray.jina.ai/_static/favicon.png', + ), + ] + ), + (3, 1), + ), + ( + DocumentArray( + [ + Document(text='hello, world'), + Document(tensor=np.array([0, 1, 2])), + Document( + uri='https://docarray.jina.ai/_static/favicon.png' + ).load_uri_to_blob(), + Document( + tensor=np.array([0, 1, 2]), + uri='https://docarray.jina.ai/_static/favicon.png', + ), + Document( + uri='https://docarray.jina.ai/_static/favicon.png', + ), + ] + ), + (1, 4), + ), + ( + DocumentArray( + [ + Document(text='hello, world'), + Document(uri='https://docarray.jina.ai/_static/favicon.png'), + ] + ), + (1, 1), + ), + ], +) +def test_split_img_txt_da(inputs): + txt_da = DocumentArray() + img_da = DocumentArray() + for doc in inputs[0]: + split_img_txt_da(doc, img_da, txt_da) + assert len(txt_da) == inputs[1][0] + assert len(img_da) == inputs[1][1]