From 6d00ac60673e17ae0bcb12b3789b995dfed92c3b Mon Sep 17 00:00:00 2001 From: Ravindra Lanka Date: Thu, 26 May 2022 14:21:28 -0700 Subject: [PATCH] Fix nullability determination for the AVRO fixed type. --- .../ingestion/extractor/schema_util.py | 12 +++++++---- .../tests/unit/test_schema_util.py | 21 +++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/extractor/schema_util.py b/metadata-ingestion/src/datahub/ingestion/extractor/schema_util.py index ff021ec328bede..faa035dbfb8262 100644 --- a/metadata-ingestion/src/datahub/ingestion/extractor/schema_util.py +++ b/metadata-ingestion/src/datahub/ingestion/extractor/schema_util.py @@ -179,10 +179,14 @@ def _is_nullable(self, schema: avro.schema.Schema) -> bool: return self._is_nullable(schema.type) if isinstance(schema, avro.schema.UnionSchema): return any(self._is_nullable(sub_schema) for sub_schema in schema.schemas) - elif isinstance(schema, avro.schema.PrimitiveSchema): - return schema.type == AVRO_TYPE_NULL or schema.props.get("_nullable", False) - else: - return self.default_nullable + if ( + isinstance(schema, avro.schema.PrimitiveSchema) + and schema.type == AVRO_TYPE_NULL + ): + return True + if isinstance(schema.props, dict): + return schema.props.get("_nullable", self.default_nullable) + return self.default_nullable def _get_cur_field_path(self) -> str: return ".".join(self._prefix_name_stack) diff --git a/metadata-ingestion/tests/unit/test_schema_util.py b/metadata-ingestion/tests/unit/test_schema_util.py index 554ed23f6a7856..bb417e64569ab7 100644 --- a/metadata-ingestion/tests/unit/test_schema_util.py +++ b/metadata-ingestion/tests/unit/test_schema_util.py @@ -63,6 +63,25 @@ } """ +SCHEMA_WITH_OPTIONAL_FIELD_VIA_FIXED_TYPE: str = json.dumps( + { + "type": "record", + "name": "__struct_", + "fields": [ + { + "name": "value", + "type": { + "type": "fixed", + "name": "__fixed_d9d2d051916045d9975d6c573aaabb89", + "size": 4, +
"native_data_type": "fixed[4]", + "_nullable": True, + }, + }, + ], + } +) + def log_field_paths(fields: List[SchemaField]) -> None: logger.debug('FieldPaths=\n"' + '",\n"'.join(f.fieldPath for f in fields) + '"') @@ -93,11 +112,13 @@ def assert_field_paths_match( SCHEMA_WITH_OPTIONAL_FIELD_VIA_UNION_TYPE, SCHEMA_WITH_OPTIONAL_FIELD_VIA_UNION_TYPE_NULL_ISNT_FIRST_IN_UNION, SCHEMA_WITH_OPTIONAL_FIELD_VIA_PRIMITIVE_TYPE, + SCHEMA_WITH_OPTIONAL_FIELD_VIA_FIXED_TYPE, ], ids=[ "optional_field_via_union_type", "optional_field_via_union_null_not_first", "optional_field_via_primitive", + "optional_field_via_fixed", ], ) def test_avro_schema_to_mce_fields_events_with_nullable_fields(schema):