Skip to content

Commit

Permalink
feat(ingest): add classification to bigquery, redshift (#10031)
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate authored Mar 14, 2024
1 parent 239ae31 commit 77c72da
Show file tree
Hide file tree
Showing 34 changed files with 692 additions and 266 deletions.
63 changes: 39 additions & 24 deletions metadata-ingestion/docs/dev_guides/classification.md

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@
| {
*sqlglot_lib,
"google-cloud-datacatalog-lineage==0.2.2",
},
}
| classification_lib,
"clickhouse": sql_common | clickhouse_common,
"clickhouse-usage": sql_common | usage_common | clickhouse_common,
"datahub-lineage-file": set(),
Expand Down Expand Up @@ -370,6 +371,8 @@
| redshift_common
| usage_common
| sqlglot_lib
| classification_lib
| {"db-dtypes"} # Pandas extension data types
| {"cachetools"},
"s3": {*s3_base, *data_lake_profiling},
"gcs": {*s3_base, *data_lake_profiling},
Expand Down
1 change: 1 addition & 0 deletions metadata-ingestion/src/datahub/ingestion/api/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class SourceCapability(Enum):
TAGS = "Extract Tags"
SCHEMA_METADATA = "Schema Metadata"
CONTAINERS = "Asset Containers"
CLASSIFICATION = "Classification"


@dataclass
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import concurrent.futures
import logging
from dataclasses import dataclass, field
from functools import partial
from math import ceil
from typing import Dict, Iterable, List, Optional
from typing import Callable, Dict, Iterable, List, Optional, Union

from datahub_classify.helper_classes import ColumnInfo, Metadata
from pydantic import Field

from datahub.configuration.common import ConfigModel, ConfigurationError
from datahub.emitter.mce_builder import get_sys_time, make_term_urn, make_user_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.glossary.classifier import ClassificationConfig, Classifier
from datahub.ingestion.glossary.classifier_registry import classifier_registry
from datahub.ingestion.source.sql.data_reader import DataReader
from datahub.metadata.com.linkedin.pegasus2avro.common import (
AuditStamp,
GlossaryTermAssociation,
Expand All @@ -25,9 +29,12 @@

@dataclass
class ClassificationReportMixin:

num_tables_fetch_sample_values_failed: int = 0

num_tables_classification_attempted: int = 0
num_tables_classification_failed: int = 0
num_tables_classified: int = 0
num_tables_classification_found: int = 0

info_types_detected: LossyDict[str, LossyList[str]] = field(
default_factory=LossyDict
Expand Down Expand Up @@ -99,8 +106,22 @@ def classify_schema_fields(
self,
dataset_name: str,
schema_metadata: SchemaMetadata,
sample_data: Dict[str, list],
sample_data: Union[Dict[str, list], Callable[[], Dict[str, list]]],
) -> None:

if not isinstance(sample_data, Dict):
try:
# TODO: In future, sample_data fetcher can be lazily called if classification
# requires values as prediction factor
sample_data = sample_data()
except Exception as e:
self.report.num_tables_fetch_sample_values_failed += 1
logger.warning(
f"Failed to get sample values for dataset. Make sure you have granted SELECT permissions on dataset. {dataset_name}",
)
sample_data = dict()
logger.debug("Error", exc_info=e)

column_infos = self.get_columns_to_classify(
dataset_name, schema_metadata, sample_data
)
Expand Down Expand Up @@ -137,7 +158,7 @@ def classify_schema_fields(
)

if field_terms:
self.report.num_tables_classified += 1
self.report.num_tables_classification_found += 1
self.populate_terms_in_schema_metadata(schema_metadata, field_terms)

def update_field_terms(
Expand Down Expand Up @@ -234,8 +255,11 @@ def get_columns_to_classify(
)
continue

# TODO: Let's auto-skip passing sample_data for complex(array/struct) columns
# for initial rollout
# Because of the custom field path specification (e.g. `[version=2.0].[type=struct].[type=struct].service`),
# sample values for a nested field (an array, union or struct) are not read / passed to the classifier correctly.
# TODO: Fix this behavior for nested fields. This would probably involve:
# 1. Preprocessing field path spec v2 back to the native field representation (without [*] constructs).
# 2. Preprocessing retrieved structured sample data to pass in sample values correctly for nested fields.

column_infos.append(
ColumnInfo(
Expand All @@ -256,3 +280,47 @@ def get_columns_to_classify(
)

return column_infos


def classification_workunit_processor(
    table_wu_generator: Iterable[MetadataWorkUnit],
    classification_handler: ClassificationHandler,
    data_reader: Optional[DataReader],
    table_id: List[str],
    data_reader_kwargs: Optional[dict] = None,
) -> Iterable[MetadataWorkUnit]:
    """Run classification over the workunits generated for a single table.

    Workunits that carry a ``SchemaMetadata`` aspect are handed to
    ``classification_handler.classify_schema_fields`` and then re-emitted as a
    ``MetadataChangeProposalWrapper`` workunit wrapping that same aspect
    object. Every other workunit is passed through untouched.

    Args:
        table_wu_generator: Workunits produced for the table.
        classification_handler: Decides whether the table is classified and
            performs the classification.
        data_reader: Used to fetch sample values lazily. When falsy, an empty
            dict is passed as sample data instead.
        table_id: Name parts of the table; joined with "." to build the table
            name given to the handler, and passed as-is to the data reader.
        data_reader_kwargs: Extra keyword arguments forwarded to
            ``data_reader.get_sample_data_for_table``. ``None`` means none.

    Yields:
        The original workunits, with schema-metadata workunits replaced by
        their classified counterpart. If classification of a workunit raises,
        the error is logged at debug level and the original workunit is
        yielded instead.
    """
    table_name = ".".join(table_id)
    if not classification_handler.is_classification_enabled_for_table(table_name):
        yield from table_wu_generator
        # Stop here. Without this, a re-iterable input (e.g. a list) would be
        # walked a second time by the loop below, emitting duplicates and
        # classifying a table for which classification is disabled.
        return

    # Avoid a shared mutable default; treat None as "no extra arguments".
    reader_kwargs = data_reader_kwargs or {}

    for wu in table_wu_generator:
        maybe_schema_metadata = wu.get_aspect_of_type(SchemaMetadata)
        if not maybe_schema_metadata:
            yield wu
            continue

        try:
            classification_handler.classify_schema_fields(
                table_name,
                maybe_schema_metadata,
                (
                    # Deferred call: the handler invokes it only when it needs
                    # the sample values.
                    partial(
                        data_reader.get_sample_data_for_table,
                        table_id,
                        # Request 20% more rows than the configured sample size
                        # (presumably headroom for unusable values - confirm).
                        # The reader declares an int sample size, so do not
                        # pass the float product through.
                        int(
                            classification_handler.config.classification.sample_size
                            * 1.2
                        ),
                        **reader_kwargs,
                    )
                    if data_reader
                    else dict()
                ),
            )
            yield MetadataChangeProposalWrapper(
                aspect=maybe_schema_metadata, entityUrn=wu.get_urn()
            ).as_workunit(
                is_primary_source=wu.is_primary_source,
            )
        except Exception as e:
            # Best effort: classification failures must not drop the workunit.
            logger.debug(
                f"Failed to classify table columns for {table_name} due to error -> {e}",
                exc_info=e,
            )
            yield wu
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,16 @@
TestConnectionReport,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.glossary.classification_mixin import (
ClassificationHandler,
classification_workunit_processor,
)
from datahub.ingestion.source.bigquery_v2.bigquery_audit import (
BigqueryTableIdentifier,
BigQueryTableRef,
)
from datahub.ingestion.source.bigquery_v2.bigquery_config import BigQueryV2Config
from datahub.ingestion.source.bigquery_v2.bigquery_data_reader import BigQueryDataReader
from datahub.ingestion.source.bigquery_v2.bigquery_helper import (
unquote_and_decode_unicode_escape_seq,
)
Expand Down Expand Up @@ -167,6 +172,11 @@ def cleanup(config: BigQueryV2Config) -> None:
"Optionally enabled via `stateful_ingestion.remove_stale_metadata`",
supported=True,
)
@capability(
SourceCapability.CLASSIFICATION,
"Optionally enabled via `classification.enabled`",
supported=True,
)
class BigqueryV2Source(StatefulIngestionSourceBase, TestableSource):
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types
BIGQUERY_FIELD_TYPE_MAPPINGS: Dict[
Expand Down Expand Up @@ -214,6 +224,7 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config):
super(BigqueryV2Source, self).__init__(config, ctx)
self.config: BigQueryV2Config = config
self.report: BigQueryV2Report = BigQueryV2Report()
self.classification_handler = ClassificationHandler(self.config, self.report)
self.platform: str = "bigquery"

BigqueryTableIdentifier._BIGQUERY_DEFAULT_SHARDED_TABLE_REGEX = (
Expand All @@ -227,6 +238,12 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config):
)
self.sql_parser_schema_resolver = self._init_schema_resolver()

self.data_reader: Optional[BigQueryDataReader] = None
if self.classification_handler.is_classification_enabled():
self.data_reader = BigQueryDataReader.create(
self.config.get_bigquery_client()
)

redundant_lineage_run_skip_handler: Optional[
RedundantLineageRunSkipHandler
] = None
Expand Down Expand Up @@ -713,6 +730,7 @@ def _process_schema(
)

columns = None

if (
self.config.include_tables
or self.config.include_views
Expand All @@ -732,12 +750,27 @@ def _process_schema(

for table in db_tables[dataset_name]:
table_columns = columns.get(table.name, []) if columns else []
yield from self._process_table(
table_wu_generator = self._process_table(
table=table,
columns=table_columns,
project_id=project_id,
dataset_name=dataset_name,
)
yield from classification_workunit_processor(
table_wu_generator,
self.classification_handler,
self.data_reader,
[project_id, dataset_name, table.name],
data_reader_kwargs=dict(
sample_size_percent=(
self.config.classification.sample_size
* 1.2
/ table.rows_count
if table.rows_count
else None
)
),
)
elif self.store_table_refs:
# Need table_refs to calculate lineage and usage
for table_item in self.bigquery_data_dictionary.list_tables(
Expand Down Expand Up @@ -1071,14 +1104,16 @@ def gen_dataset_workunits(
)

yield self.gen_schema_metadata(
dataset_urn, table, columns, str(datahub_dataset_name)
dataset_urn, table, columns, datahub_dataset_name
)

dataset_properties = DatasetProperties(
name=datahub_dataset_name.get_table_display_name(),
description=unquote_and_decode_unicode_escape_seq(table.comment)
if table.comment
else "",
description=(
unquote_and_decode_unicode_escape_seq(table.comment)
if table.comment
else ""
),
qualifiedName=str(datahub_dataset_name),
created=(
TimeStamp(time=int(table.created.timestamp() * 1000))
Expand Down Expand Up @@ -1238,10 +1273,10 @@ def gen_schema_metadata(
dataset_urn: str,
table: Union[BigqueryTable, BigqueryView, BigqueryTableSnapshot],
columns: List[BigqueryColumn],
dataset_name: str,
dataset_name: BigqueryTableIdentifier,
) -> MetadataWorkUnit:
schema_metadata = SchemaMetadata(
schemaName=dataset_name,
schemaName=str(dataset_name),
platform=make_data_platform_urn(self.platform),
version=0,
hash="",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

from datahub.configuration.common import AllowDenyPattern, ConfigModel
from datahub.configuration.validate_field_removal import pydantic_removed_field
from datahub.ingestion.glossary.classification_mixin import (
ClassificationSourceConfigMixin,
)
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulLineageConfigMixin,
Expand Down Expand Up @@ -64,9 +67,9 @@ def __init__(self, **data: Any):
)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._credentials_path

def get_bigquery_client(config) -> bigquery.Client:
client_options = config.extra_client_options
return bigquery.Client(config.project_on_behalf, **client_options)
def get_bigquery_client(self) -> bigquery.Client:
client_options = self.extra_client_options
return bigquery.Client(self.project_on_behalf, **client_options)

def make_gcp_logging_client(
self, project_id: Optional[str] = None
Expand Down Expand Up @@ -96,6 +99,7 @@ class BigQueryV2Config(
StatefulUsageConfigMixin,
StatefulLineageConfigMixin,
StatefulProfilingConfigMixin,
ClassificationSourceConfigMixin,
):
project_id_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import logging
from collections import defaultdict
from typing import Dict, List, Optional

from google.cloud import bigquery

from datahub.ingestion.source.sql.data_reader import DataReader
from datahub.utilities.perf_timer import PerfTimer

# Obtain the logger through the logging registry so that it is attached to the
# logger hierarchy and honours the application's handler/level configuration.
# Instantiating logging.Logger directly yields a detached logger whose debug
# output never reaches configured handlers.
logger = logging.getLogger(__name__)


class BigQueryDataReader(DataReader):
    """DataReader that collects sample column values from BigQuery tables."""

    @staticmethod
    def create(
        client: bigquery.Client,
    ) -> "BigQueryDataReader":
        """Build a reader around an existing BigQuery client."""
        return BigQueryDataReader(client)

    def __init__(
        self,
        client: bigquery.Client,
    ) -> None:
        self.client = client

    def get_sample_data_for_table(
        self,
        table_id: List[str],
        sample_size: int,
        *,
        sample_size_percent: Optional[float] = None,
        filter: Optional[str] = None,
    ) -> Dict[str, list]:
        """
        Fetch sample values for every column of a table.

        table_id should be in the form [project, dataset, table]

        sample_size_percent is the fraction of the table to sample (it is
        multiplied by 100 to build the TABLESAMPLE percentage). When it is
        None no query is issued and an empty mapping is returned.

        sample_size and filter are accepted but not used by this
        implementation; the amount of data read is driven solely by
        sample_size_percent.

        Returns a mapping of column name to the list of sampled values.
        """

        assert len(table_id) == 3
        project = table_id[0]
        dataset = table_id[1]
        table_name = table_id[2]

        column_values: Dict[str, list] = defaultdict(list)
        if sample_size_percent is None:
            return column_values
        # Ideally we always know the actual row count.
        # The alternative to perform limit query scans entire BQ table
        # and is never a recommended option due to cost factor, unless
        # additional filter clause (e.g. where condition on partition) is available.

        logger.debug(
            f"Collecting sample values for table {project}.{dataset}.{table_name}"
        )
        with PerfTimer() as timer:
            # Tables with fewer rows than the requested sample produce a
            # fraction above 1; BigQuery only accepts percentages up to 100,
            # so cap the value instead of letting the query fail.
            sample_pc = min(sample_size_percent * 100, 100.0)
            # TODO: handle for sharded+compulsory partitioned tables
            sql = (
                f"SELECT * FROM `{project}.{dataset}.{table_name}` "
                + f"TABLESAMPLE SYSTEM ({sample_pc:.8f} percent)"
            )
            # Ref: https://cloud.google.com/bigquery/docs/samples/bigquery-query-results-dataframe
            df = self.client.query_and_wait(sql).to_dataframe()
            time_taken = timer.elapsed_seconds()
            logger.debug(
                f"Finished collecting sample values for table {project}.{dataset}.{table_name};"
                f"{df.shape[0]} rows; took {time_taken:.3f} seconds"
            )

        return df.to_dict(orient="list")

    def close(self) -> None:
        """Close the underlying BigQuery client."""
        self.client.close()
Loading

0 comments on commit 77c72da

Please sign in to comment.