Skip to content

Commit

Permalink
adding tests for padding_collate
Browse files Browse the repository at this point in the history
  • Loading branch information
pstjohn committed Aug 2, 2024
1 parent 9741281 commit c03d718
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
3 changes: 3 additions & 0 deletions sub-packages/bionemo-llm/src/bionemo/llm/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def padding_collate_fn(
Returns:
A collated batch with the same dictionary input structure.
"""
for entry in batch:
if entry.keys() != padding_values.keys():
raise ValueError("All keys in inputs must match provided padding_values.")

def _pad(tensors, padding_value):
if max_length is not None:
Expand Down
29 changes: 28 additions & 1 deletion sub-packages/bionemo-llm/tests/bionemo/llm/data/test_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,36 @@
# limitations under the License.


import pytest
import torch

from bionemo.llm.data.collate import bert_padding_collate_fn
from bionemo.llm.data.collate import bert_padding_collate_fn, padding_collate_fn


def test_padding_collate_fn():
sample1 = {
"my_key": torch.tensor([1, 2, 3]),
}
sample2 = {
"my_key": torch.tensor([4, 5, 6, 7, 8]),
}
batch = [sample1, sample2]
collated_batch = padding_collate_fn(batch, padding_values={"my_key": -1})

assert torch.all(torch.eq(collated_batch["my_key"], torch.tensor([[1, 2, 3, -1, -1], [4, 5, 6, 7, 8]])))


def test_padding_collate_with_missing_key_raises():
sample1 = {
"my_key": torch.tensor([1, 2, 3]),
}
sample2 = {
"my_key": torch.tensor([4, 5, 6, 7, 8]),
"other_key": torch.tensor([1, 2, 3]),
}
batch = [sample1, sample2]
with pytest.raises(ValueError, match="All keys in inputs must match provided padding_values."):
padding_collate_fn(batch, padding_values={"my_key": -1, "other_key": -1})


def test_bert_padding_collate_fn():
Expand Down

0 comments on commit c03d718

Please sign in to comment.