From 361f674e473c5a66fca47d095174af6aef6ea896 Mon Sep 17 00:00:00 2001 From: Andrey Vasnetsov Date: Wed, 13 Mar 2024 19:08:51 +0100 Subject: [PATCH] avoid changing output dimensionality for a single input (#148) --- fastembed/sparse/splade_pp.py | 3 +-- tests/test_sparse_embeddings.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index d1301c34..730a533f 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -27,11 +27,10 @@ def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> Ite weighted_log = relu_log * np.expand_dims(attention_mask, axis=-1) - max_val = np.max(weighted_log, axis=1) + scores = np.max(weighted_log, axis=1) # Score matrix of shape (batch_size, vocab_size) # Most of the values are 0, only a few are non-zero - scores = np.squeeze(max_val) for row_scores in scores: indices = row_scores.nonzero()[0] scores = row_scores[indices] diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py index 36d1f7e3..6c4bb8a1 100644 --- a/tests/test_sparse_embeddings.py +++ b/tests/test_sparse_embeddings.py @@ -40,3 +40,18 @@ def test_batch_embedding(): for i, value in enumerate(result.values): assert pytest.approx(value, abs=0.001) == expected_result["values"][i] + + +def test_single_embedding(): + docs_to_embed = docs + + for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): + print("evaluating", model_name) + model = SparseTextEmbedding(model_name=model_name) + result = next(iter(model.embed(docs_to_embed, batch_size=6))) + print(result.indices) + + assert result.indices.tolist() == expected_result["indices"] + + for i, value in enumerate(result.values): + assert pytest.approx(value, abs=0.001) == expected_result["values"][i]