Skip to content

Commit

Permalink
feat(ingest): add classification for sql sources (#10013)
Browse files Browse the repository at this point in the history
Co-authored-by: Harshal Sheth <[email protected]>
  • Loading branch information
mayurinehate and hsheth2 authored Mar 12, 2024
1 parent 28f16aa commit 2de0e62
Show file tree
Hide file tree
Showing 19 changed files with 976 additions and 1,030 deletions.
8 changes: 6 additions & 2 deletions metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@
"acryl-sqlglot==22.3.1.dev3",
}

classification_lib = {
"acryl-datahub-classify==0.0.9",
}

sql_common = (
{
# Required for all SQL sources.
Expand All @@ -121,6 +125,7 @@
}
| usage_common
| sqlglot_lib
| classification_lib
)

sqllineage_lib = {
Expand Down Expand Up @@ -190,8 +195,7 @@
"pandas",
"cryptography",
"msal",
"acryl-datahub-classify==0.0.9",
}
} | classification_lib

trino = {
"trino[sqlalchemy]>=0.308",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ def get_columns_to_classify(
f"Skipping column {dataset_name}.{schema_field.fieldPath} from classification"
)
continue

# TODO: Let's auto-skip passing sample_data for complex(array/struct) columns
# for initial rollout

column_infos.append(
ColumnInfo(
metadata=Metadata(
Expand All @@ -243,9 +247,11 @@ def get_columns_to_classify(
"Dataset_Name": dataset_name,
}
),
values=sample_data[schema_field.fieldPath]
if schema_field.fieldPath in sample_data.keys()
else [],
values=(
sample_data[schema_field.fieldPath]
if schema_field.fieldPath in sample_data.keys()
else []
),
)
)

Expand Down
136 changes: 136 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/sql/data_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import logging
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Union

import sqlalchemy as sa
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.row import LegacyRow

from datahub.ingestion.api.closeable import Closeable

logger: logging.Logger = logging.getLogger(__name__)


class DataReader(Closeable):
@abstractmethod
def get_sample_data_for_column(
self, table_id: List[str], column_name: str, sample_size: int = 100
) -> list:
pass

@abstractmethod
def get_sample_data_for_table(
self, table_id: List[str], sample_size: int = 100
) -> Dict[str, list]:
pass


class SqlAlchemyTableDataReader(DataReader):
@staticmethod
def create(inspector: Inspector) -> "SqlAlchemyTableDataReader":
return SqlAlchemyTableDataReader(conn=inspector.bind)

def __init__(
self,
conn: Union[Engine, Connection],
) -> None:
# TODO: How can this use a connection pool instead ?
self.engine = conn.engine.connect()

def _table(self, table_id: List[str]) -> sa.Table:
return sa.Table(
table_id[-1],
sa.MetaData(),
schema=table_id[-2] if len(table_id) > 1 else None,
)

def get_sample_data_for_column(
self, table_id: List[str], column_name: str, sample_size: int = 100
) -> list:
"""
Fetches non-null column values, upto <sample_size> count
Args:
table_id: Table name identifier. One of
- [<db_name>, <schema_name>, <table_name>] or
- [<schema_name>, <table_name>] or
- [<table_name>]
column: Column name
Returns:
list of column values
"""

table = self._table(table_id)
query: Any
ignore_null_condition = sa.column(column_name).is_(None)
# limit doesn't compile properly for oracle so we will append rownum to query string later
if self.engine.dialect.name.lower() == "oracle":
raw_query = (
sa.select([sa.column(column_name)])
.select_from(table)
.where(sa.not_(ignore_null_condition))
)

query = str(
raw_query.compile(self.engine, compile_kwargs={"literal_binds": True})
)
query += "\nAND ROWNUM <= %d" % sample_size
else:
query = (
sa.select([sa.column(column_name)])
.select_from(table)
.where(sa.not_(ignore_null_condition))
.limit(sample_size)
)
query_results = self.engine.execute(query)

return [x[column_name] for x in query_results.fetchall()]

def get_sample_data_for_table(
self, table_id: List[str], sample_size: int = 100
) -> Dict[str, list]:
"""
Fetches table values, upto <sample_size>*1.2 count
Args:
table_id: Table name identifier. One of
- [<db_name>, <schema_name>, <table_name>] or
- [<schema_name>, <table_name>] or
- [<table_name>]
Returns:
dictionary of (column name -> list of column values)
"""
column_values: Dict[str, list] = defaultdict(list)
table = self._table(table_id)

# Ideally we do not want null values in sample data for a column.
# However that would require separate query per column and
# that would be expensiv. To compensate for possibility
# of some null values in collected sample, we fetch extra (20% more)
# rows than configured sample_size.
sample_size = int(sample_size * 1.2)

query: Any

# limit doesn't compile properly for oracle so we will append rownum to query string later
if self.engine.dialect.name.lower() == "oracle":
raw_query = sa.select([sa.text("*")]).select_from(table)

query = str(
raw_query.compile(self.engine, compile_kwargs={"literal_binds": True})
)
query += "\nAND ROWNUM <= %d" % sample_size
else:
query = sa.select([sa.text("*")]).select_from(table).limit(sample_size)
query_results = self.engine.execute(query)

# Not ideal - creates a parallel structure in column_values. Can we use pandas here ?
for row in query_results.fetchall():
if isinstance(row, LegacyRow):
for col, col_value in row.items():
column_values[col].append(col_value)

return column_values

def close(self) -> None:
self.engine.close()
119 changes: 94 additions & 25 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import datetime
import logging
import traceback
Expand Down Expand Up @@ -43,10 +44,18 @@
TestConnectionReport,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.glossary.classification_mixin import (
ClassificationHandler,
ClassificationReportMixin,
)
from datahub.ingestion.source.common.subtypes import (
DatasetContainerSubTypes,
DatasetSubTypes,
)
from datahub.ingestion.source.sql.data_reader import (
DataReader,
SqlAlchemyTableDataReader,
)
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig
from datahub.ingestion.source.sql.sql_utils import (
add_table_to_schema_container,
Expand Down Expand Up @@ -120,7 +129,7 @@


@dataclass
class SQLSourceReport(StaleEntityRemovalSourceReport):
class SQLSourceReport(StaleEntityRemovalSourceReport, ClassificationReportMixin):
tables_scanned: int = 0
views_scanned: int = 0
entities_profiled: int = 0
Expand Down Expand Up @@ -314,6 +323,7 @@ def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str)
self.report: SQLSourceReport = SQLSourceReport()
self.profile_metadata_info: ProfileMetadata = ProfileMetadata()

self.classification_handler = ClassificationHandler(self.config, self.report)
config_report = {
config_option: config.dict().get(config_option)
for config_option in config_options_to_report
Expand Down Expand Up @@ -643,38 +653,61 @@ def get_foreign_key_metadata(
fk_dict["name"], foreign_fields, source_fields, foreign_dataset
)

def make_data_reader(self, inspector: Inspector) -> Optional[DataReader]:
"""
Subclasses can override this with source-specific data reader
if source provides clause to pick random sample instead of current
limit-based sample
"""
if (
self.classification_handler
and self.classification_handler.is_classification_enabled()
):
return SqlAlchemyTableDataReader.create(inspector)

return None

def loop_tables( # noqa: C901
self,
inspector: Inspector,
schema: str,
sql_config: SQLCommonConfig,
) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]:
tables_seen: Set[str] = set()
try:
for table in inspector.get_table_names(schema):
dataset_name = self.get_identifier(
schema=schema, entity=table, inspector=inspector
)

if dataset_name not in tables_seen:
tables_seen.add(dataset_name)
else:
logger.debug(f"{dataset_name} has already been seen, skipping...")
continue

self.report.report_entity_scanned(dataset_name, ent_type="table")
if not sql_config.table_pattern.allowed(dataset_name):
self.report.report_dropped(dataset_name)
continue

try:
yield from self._process_table(
dataset_name, inspector, schema, table, sql_config
data_reader = self.make_data_reader(inspector)
with (data_reader or contextlib.nullcontext()):
try:
for table in inspector.get_table_names(schema):
dataset_name = self.get_identifier(
schema=schema, entity=table, inspector=inspector
)
except Exception as e:
self.warn(logger, f"{schema}.{table}", f"Ingestion error: {e}")
except Exception as e:
self.error(logger, f"{schema}", f"Tables error: {e}")

if dataset_name not in tables_seen:
tables_seen.add(dataset_name)
else:
logger.debug(
f"{dataset_name} has already been seen, skipping..."
)
continue

self.report.report_entity_scanned(dataset_name, ent_type="table")
if not sql_config.table_pattern.allowed(dataset_name):
self.report.report_dropped(dataset_name)
continue

try:
yield from self._process_table(
dataset_name,
inspector,
schema,
table,
sql_config,
data_reader,
)
except Exception as e:
self.warn(logger, f"{schema}.{table}", f"Ingestion error: {e}")
except Exception as e:
self.error(logger, f"{schema}", f"Tables error: {e}")

def add_information_for_schema(self, inspector: Inspector, schema: str) -> None:
pass
Expand All @@ -691,6 +724,7 @@ def _process_table(
schema: str,
table: str,
sql_config: SQLCommonConfig,
data_reader: Optional[DataReader],
) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]:
columns = self._get_columns(dataset_name, inspector, schema, table)
dataset_urn = make_dataset_urn_with_platform_instance(
Expand Down Expand Up @@ -740,6 +774,8 @@ def _process_table(
foreign_keys,
schema_fields,
)
self._classify(dataset_name, schema, table, data_reader, schema_metadata)

dataset_snapshot.aspects.append(schema_metadata)
if self.config.include_view_lineage:
self.schema_resolver.add_schema_metadata(dataset_urn, schema_metadata)
Expand Down Expand Up @@ -770,6 +806,39 @@ def _process_table(
domain_registry=self.domain_registry,
)

def _classify(
self,
dataset_name: str,
schema: str,
table: str,
data_reader: Optional[DataReader],
schema_metadata: SchemaMetadata,
) -> None:
try:
if (
self.classification_handler.is_classification_enabled_for_table(
dataset_name
)
and data_reader
):
self.classification_handler.classify_schema_fields(
dataset_name,
schema_metadata,
data_reader.get_sample_data_for_table(
table_id=[schema, table],
sample_size=self.config.classification.sample_size,
),
)
except Exception as e:
logger.debug(
f"Failed to classify table columns for {dataset_name} due to error -> {e}",
exc_info=e,
)
self.report.report_warning(
"Failed to classify table columns",
dataset_name,
)

def get_database_properties(
self, inspector: Inspector, database: str
) -> Optional[Dict[str, str]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
LowerCaseDatasetUrnConfigMixin,
)
from datahub.configuration.validate_field_removal import pydantic_removed_field
from datahub.ingestion.glossary.classification_mixin import (
ClassificationSourceConfigMixin,
)
from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StatefulStaleMetadataRemovalConfig,
Expand All @@ -29,6 +32,7 @@ class SQLCommonConfig(
DatasetSourceConfigMixin,
LowerCaseDatasetUrnConfigMixin,
LineageConfig,
ClassificationSourceConfigMixin,
):
options: dict = pydantic.Field(
default_factory=dict,
Expand Down
Loading

0 comments on commit 2de0e62

Please sign in to comment.