feat(ingest/mongodb): Ingest databases as containers #11178

Merged: 5 commits, Aug 21, 2024
180 changes: 102 additions & 78 deletions metadata-ingestion/src/datahub/ingestion/source/mongodb.py
@@ -18,9 +18,13 @@
from datahub.emitter.mce_builder import (
make_data_platform_urn,
make_dataplatform_instance_urn,
make_dataset_urn_with_platform_instance,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import (
DatabaseKey,
add_dataset_to_container,
gen_containers,
)
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
SourceCapability,
@@ -32,6 +36,7 @@
)
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
from datahub.ingestion.source.schema_inference.object import (
SchemaDescription,
construct_schema,
@@ -64,6 +69,7 @@
DataPlatformInstanceClass,
DatasetPropertiesClass,
)
from datahub.metadata.urns import DatasetUrn

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -263,6 +269,7 @@ class MongoDBSource(StatefulIngestionSourceBase):
config: MongoDBConfig
report: MongoDBSourceReport
mongo_client: MongoClient
platform: str = "mongodb"

def __init__(self, ctx: PipelineContext, config: MongoDBConfig):
super().__init__(config, ctx)
@@ -282,7 +289,9 @@ def __init__(self, ctx: PipelineContext, config: MongoDBConfig):
}

# See https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes
self.mongo_client = MongoClient(self.config.connect_uri, datetime_conversion="DATETIME_AUTO", **options) # type: ignore
self.mongo_client = MongoClient(
self.config.connect_uri, datetime_conversion="DATETIME_AUTO", **options
) # type: ignore

# This cheaply tests the connection. For details, see
# https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient
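For context, the reformatted client construction above is just a line-wrap of the existing call; a minimal standalone sketch, assuming pymongo 4.3+ (where the datetime_conversion option exists) and using a placeholder connection URI and options dict:

```python
from pymongo import MongoClient

# Placeholder connection settings, for illustration only.
connect_uri = "mongodb://localhost:27017"
options: dict = {}

# DATETIME_AUTO decodes out-of-range BSON datetimes as DatetimeMS instead of
# raising, so ingestion does not fail on bad stored dates.
client = MongoClient(connect_uri, datetime_conversion="DATETIME_AUTO", **options)

# A cheap connectivity check, mirroring the pattern referenced in the comment above.
client.admin.command("ping")
```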
@@ -351,8 +360,6 @@ def get_field_type(
return SchemaFieldDataType(type=TypeClass())

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
platform = "mongodb"

database_names: List[str] = self.mongo_client.list_database_names()

# traverse databases in sorted order so output is consistent
@@ -364,8 +371,19 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
continue

database = self.mongo_client[database_name]
collection_names: List[str] = database.list_collection_names()
database_key = DatabaseKey(
database=database_name,
platform=self.platform,
instance=self.config.platform_instance,
env=self.config.env,
)
yield from gen_containers(
container_key=database_key,
name=database_name,
sub_types=[DatasetContainerSubTypes.DATABASE],
)

collection_names: List[str] = database.list_collection_names()
# traverse collections in sorted order so output is consistent
for collection_name in sorted(collection_names):
dataset_name = f"{database_name}.{collection_name}"
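To see the new container emission in isolation: a minimal sketch using the same DataHub helpers as the hunk above, with a placeholder database name and env values that are not part of this PR:

```python
from datahub.emitter.mcp_builder import DatabaseKey, gen_containers
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes

# Placeholder identifiers for illustration.
database_key = DatabaseKey(
    database="my_database",
    platform="mongodb",
    instance=None,  # or a platform instance name
    env="PROD",
)

# gen_containers yields the work units that create the container entity,
# mark it with the DATABASE subtype, and record the container key properties.
for workunit in gen_containers(
    container_key=database_key,
    name="my_database",
    sub_types=[DatasetContainerSubTypes.DATABASE],
):
    print(workunit.id)
```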
@@ -374,9 +392,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
self.report.report_dropped(dataset_name)
continue

dataset_urn = make_dataset_urn_with_platform_instance(
platform=platform,
name=dataset_name,
dataset_urn = DatasetUrn.create_from_ids(
platform_id=self.platform,
table_name=dataset_name,
env=self.config.env,
platform_instance=self.config.platform_instance,
)
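The change above swaps make_dataset_urn_with_platform_instance for the typed DatasetUrn.create_from_ids; a rough sketch of what that produces, with placeholder names (the rendered URN string is my expectation of the default format, not taken from the PR):

```python
from datahub.metadata.urns import DatasetUrn

dataset_urn = DatasetUrn.create_from_ids(
    platform_id="mongodb",
    table_name="my_database.my_collection",
    env="PROD",
    platform_instance=None,
)

# Expected rendering without a platform instance:
# urn:li:dataset:(urn:li:dataPlatform:mongodb,my_database.my_collection,PROD)
print(dataset_urn.urn())
```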
@@ -385,9 +403,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
data_platform_instance = None
if self.config.platform_instance:
data_platform_instance = DataPlatformInstanceClass(
platform=make_data_platform_urn(platform),
platform=make_data_platform_urn(self.platform),
instance=make_dataplatform_instance_urn(
platform, self.config.platform_instance
self.platform, self.config.platform_instance
),
)

@@ -397,83 +415,21 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
)

schema_metadata: Optional[SchemaMetadata] = None

if self.config.enableSchemaInference:
assert self.config.maxDocumentSize is not None
collection_schema = construct_schema_pymongo(
database[collection_name],
delimiter=".",
use_random_sampling=self.config.useRandomSampling,
max_document_size=self.config.maxDocumentSize,
should_add_document_size_filter=self.should_add_document_size_filter(),
sample_size=self.config.schemaSamplingSize,
)

# initialize the schema for the collection
canonical_schema: List[SchemaField] = []
max_schema_size = self.config.maxSchemaSize
collection_schema_size = len(collection_schema.values())
collection_fields: Union[
List[SchemaDescription], ValuesView[SchemaDescription]
] = collection_schema.values()
assert max_schema_size is not None
if collection_schema_size > max_schema_size:
# downsample the schema, using frequency as the sort key
self.report.report_warning(
title="Too many schema fields",
message=f"Downsampling the collection schema because it has too many schema fields. Configured threshold is {max_schema_size}",
context=f"Schema Size: {collection_schema_size}, Collection: {dataset_urn}",
)
# Add this information to the custom properties so user can know they are looking at downsampled schema
dataset_properties.customProperties[
"schema.downsampled"
] = "True"
dataset_properties.customProperties[
"schema.totalFields"
] = f"{collection_schema_size}"

logger.debug(
f"Size of collection fields = {len(collection_fields)}"
)
# append each schema field (sort so output is consistent)
for schema_field in sorted(
collection_fields,
key=lambda x: (
-x["count"],
x["delimited_name"],
), # Negate `count` for descending order, `delimited_name` stays the same for ascending
)[0:max_schema_size]:
field = SchemaField(
fieldPath=schema_field["delimited_name"],
nativeDataType=self.get_pymongo_type_string(
schema_field["type"], dataset_name
),
type=self.get_field_type(
schema_field["type"], dataset_name
),
description=None,
nullable=schema_field["nullable"],
recursive=False,
)
canonical_schema.append(field)

# create schema metadata object for collection
schema_metadata = SchemaMetadata(
schemaName=collection_name,
platform=f"urn:li:dataPlatform:{platform}",
version=0,
hash="",
platformSchema=SchemalessClass(),
fields=canonical_schema,
schema_metadata = self._infer_schema_metadata(
collection=database[collection_name],
dataset_urn=dataset_urn,
dataset_properties=dataset_properties,
)

# TODO: use list_indexes() or index_information() to get index information
# See https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.list_indexes.

yield from add_dataset_to_container(database_key, dataset_urn.urn())
yield from [
mcp.as_workunit()
for mcp in MetadataChangeProposalWrapper.construct_many(
entityUrn=dataset_urn,
entityUrn=dataset_urn.urn(),
aspects=[
schema_metadata,
dataset_properties,
@@ -482,6 +438,74 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
)
]
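Putting the pieces together, each collection is now linked to its database container before its aspects are emitted. A condensed sketch of that flow as a standalone generator, under the same assumptions as above; construct_many skips None aspects as I understand the emitter, so optional aspects such as schema_metadata can be passed directly:

```python
from typing import Iterable, List

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import DatabaseKey, add_dataset_to_container
from datahub.ingestion.api.workunit import MetadataWorkUnit


def emit_collection(
    database_key: DatabaseKey, dataset_urn: str, aspects: List
) -> Iterable[MetadataWorkUnit]:
    # Attach the collection (dataset) to its database container first...
    yield from add_dataset_to_container(database_key, dataset_urn)
    # ...then emit the remaining aspects for the same dataset URN.
    yield from [
        mcp.as_workunit()
        for mcp in MetadataChangeProposalWrapper.construct_many(
            entityUrn=dataset_urn, aspects=aspects
        )
    ]
```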

def _infer_schema_metadata(
self,
collection: pymongo.collection.Collection,
dataset_urn: DatasetUrn,
dataset_properties: DatasetPropertiesClass,
) -> SchemaMetadata:
assert self.config.maxDocumentSize is not None
collection_schema = construct_schema_pymongo(
collection,
delimiter=".",
use_random_sampling=self.config.useRandomSampling,
max_document_size=self.config.maxDocumentSize,
should_add_document_size_filter=self.should_add_document_size_filter(),
sample_size=self.config.schemaSamplingSize,
)

# initialize the schema for the collection
canonical_schema: List[SchemaField] = []
max_schema_size = self.config.maxSchemaSize
collection_schema_size = len(collection_schema.values())
collection_fields: Union[
List[SchemaDescription], ValuesView[SchemaDescription]
] = collection_schema.values()
assert max_schema_size is not None
if collection_schema_size > max_schema_size:
# downsample the schema, using frequency as the sort key
self.report.report_warning(
title="Too many schema fields",
message=f"Downsampling the collection schema because it has too many schema fields. Configured threshold is {max_schema_size}",
context=f"Schema Size: {collection_schema_size}, Collection: {dataset_urn}",
)
# Add this information to the custom properties so user can know they are looking at downsampled schema
dataset_properties.customProperties["schema.downsampled"] = "True"
dataset_properties.customProperties[
"schema.totalFields"
] = f"{collection_schema_size}"

logger.debug(f"Size of collection fields = {len(collection_fields)}")
# append each schema field (sort so output is consistent)
for schema_field in sorted(
collection_fields,
key=lambda x: (
-x["count"],
x["delimited_name"],
), # Negate `count` for descending order, `delimited_name` stays the same for ascending
)[0:max_schema_size]:
field = SchemaField(
fieldPath=schema_field["delimited_name"],
nativeDataType=self.get_pymongo_type_string(
schema_field["type"], dataset_urn.name
),
type=self.get_field_type(schema_field["type"], dataset_urn.name),
description=None,
nullable=schema_field["nullable"],
recursive=False,
)
canonical_schema.append(field)

# create schema metadata object for collection
return SchemaMetadata(
schemaName=collection.name,
platform=f"urn:li:dataPlatform:{self.platform}",
version=0,
hash="",
platformSchema=SchemalessClass(),
fields=canonical_schema,
)
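The downsampling step inside _infer_schema_metadata is plain Python; a toy, self-contained example of the sort-and-truncate behavior with made-up field descriptions, showing that fields are kept by descending frequency with the delimited name as the tie-breaker:

```python
# Made-up field descriptions, mimicking the shape produced by construct_schema_pymongo.
collection_fields = [
    {"delimited_name": "address.city", "count": 40},
    {"delimited_name": "name", "count": 100},
    {"delimited_name": "address.zip", "count": 40},
    {"delimited_name": "email", "count": 95},
]
max_schema_size = 3

kept = sorted(
    collection_fields,
    # Negate count for descending frequency; delimited_name ascending breaks ties.
    key=lambda x: (-x["count"], x["delimited_name"]),
)[:max_schema_size]

print([f["delimited_name"] for f in kept])
# -> ['name', 'email', 'address.city']
```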

def is_server_version_gte_4_4(self) -> bool:
try:
server_version = self.mongo_client.server_info().get("versionArray")