Skip to content

Commit

Permalink
Allow null values in dict columns
Browse files (browse the repository at this point in the history)
  • Loading branch information
mariosasko committed Mar 19, 2024
1 parent 0b55ec5 commit 13222ce
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,7 +1937,7 @@ def encode_column(self, column, column_name: str):
`list[Any]`
"""
column = cast_to_python_objects(column)
return [encode_nested_example(self[column_name], obj) for obj in column]
return [encode_nested_example(self[column_name], obj, level=1) for obj in column]

def encode_batch(self, batch):
"""
Expand All @@ -1955,7 +1955,7 @@ def encode_batch(self, batch):
raise ValueError(f"Column mismatch between batch {set(batch)} and features {set(self)}")
for key, column in batch.items():
column = cast_to_python_objects(column)
encoded_batch[key] = [encode_nested_example(self[key], obj) for obj in column]
encoded_batch[key] = [encode_nested_example(self[key], obj, level=1) for obj in column]
return encoded_batch

def decode_example(self, example: dict, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None):
Expand Down
10 changes: 10 additions & 0 deletions tests/features/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,16 @@ def test_encode_batch_with_example_with_empty_first_elem():
assert encoded_batch == {"x": [[[0], [1]], [[], [1]]]}


def test_encode_column_dict_with_none():
    """Check that ``encode_column`` accepts a ``None`` row in a dict column.

    Column ``"x"`` is declared as a dict with a ``ClassLabel`` field ``"a"``
    and an ``int32`` field ``"b"``. The non-null row is expected to come back
    with its label name replaced by the class index, and the ``None`` row is
    expected to come back as ``None``.
    """
    schema = {
        "x": {"a": ClassLabel(names=["a", "b"]), "b": Value("int32")},
    }
    features = Features(schema)

    rows = [{"a": "a", "b": 1}, None]
    expected = [{"a": 0, "b": 1}, None]

    assert features.encode_column(rows, "x") == expected


@pytest.mark.parametrize(
"feature",
[
Expand Down

0 comments on commit 13222ce

Please sign in to comment.