Skip to content

Commit

Permalink
avoid changing output dimentionality for a single input (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
generall authored Mar 13, 2024
1 parent 041a606 commit 361f674
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
3 changes: 1 addition & 2 deletions fastembed/sparse/splade_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ def _post_process_onnx_output(cls, output: Tuple[np.ndarray, np.ndarray]) -> Ite

weighted_log = relu_log * np.expand_dims(attention_mask, axis=-1)

max_val = np.max(weighted_log, axis=1)
scores = np.max(weighted_log, axis=1)

# Score matrix of shape (batch_size, vocab_size)
# Most of the values are 0, only a few are non-zero
scores = np.squeeze(max_val)
for row_scores in scores:
indices = row_scores.nonzero()[0]
scores = row_scores[indices]
Expand Down
15 changes: 15 additions & 0 deletions tests/test_sparse_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,18 @@ def test_batch_embedding():

for i, value in enumerate(result.values):
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]


def test_single_embedding():
docs_to_embed = docs

for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
print("evaluating", model_name)
model = SparseTextEmbedding(model_name=model_name)
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
print(result.indices)

assert result.indices.tolist() == expected_result["indices"]

for i, value in enumerate(result.values):
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]

0 comments on commit 361f674

Please sign in to comment.