From 17f17b3fe7f276e1b019cca8aa651bf7c818a928 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 21 Nov 2024 16:09:46 +0000 Subject: [PATCH] support for custom feature encoding/decoding (#7284) * support for custom feature encoding/decoding * Update src/datasets/features/features.py --------- Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- src/datasets/features/features.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 34622cd94d9..aac6bff343c 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1348,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0): return list(obj) # Object with special encoding: # ClassLabel will convert from string to int, TranslationVariableLanguages does some checks - elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)): + elif hasattr(schema, "encode_example"): return schema.encode_example(obj) if obj is not None else None # Other object should be directly convertible to a native Arrow type (like Translation and Translation) return obj @@ -1399,10 +1399,9 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni else: return decode_nested_example([schema.feature], obj) # Object with special decoding: - elif isinstance(schema, (Audio, Image, Video)): + elif hasattr(schema, "decode_example") and getattr(schema, "decode", True): # we pass the token to read and decode files from private repositories in streaming mode - if obj is not None and schema.decode: - return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) + return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) if obj is not None else None return obj @@ -1629,7 +1628,9 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False elif isinstance(feature, Sequence): return require_decoding(feature.feature) else: - return hasattr(feature, "decode_example") and (feature.decode if not ignore_decode_attribute else True) + return hasattr(feature, "decode_example") and ( + getattr(feature, "decode", True) if not ignore_decode_attribute else True + ) def require_storage_cast(feature: FeatureType) -> bool: