Skip to content

Commit

Permalink
feat(ingest): add classification for sql sources
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate committed Mar 8, 2024
1 parent 6e8a2eb commit 764ea8a
Show file tree
Hide file tree
Showing 13 changed files with 316 additions and 390 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ def get_columns_to_classify(
f"Skipping column {dataset_name}.{schema_field.fieldPath} from classification"
)
continue

# TODO: Let's auto-skip passing sample_data for complex(array/struct) columns
# for initial rollout

column_infos.append(
ColumnInfo(
metadata=Metadata(
Expand All @@ -243,9 +247,11 @@ def get_columns_to_classify(
"Dataset_Name": dataset_name,
}
),
values=sample_data[schema_field.fieldPath]
if schema_field.fieldPath in sample_data.keys()
else [],
values=(
sample_data[schema_field.fieldPath]
if schema_field.fieldPath in sample_data.keys()
else []
),
)
)

Expand Down
135 changes: 135 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/sql/data_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Union

import sqlalchemy as sa
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.row import LegacyRow


class DataReader(ABC):
@abstractmethod
def get_data_for_column(
self, table_id: List[str], column_name: str, sample_size: int = 100
) -> list:
pass

@abstractmethod
def get_data_for_table(
self, table_id: List[str], sample_size: int = 100
) -> Dict[str, list]:
pass

@abstractmethod
def close(self) -> None:
pass


class GenericSqlTableDataReader(DataReader):
@staticmethod
def create(inspector: Inspector) -> "GenericSqlTableDataReader":
return GenericSqlTableDataReader(conn=inspector.bind)

def __init__(
self,
conn: Union[Engine, Connection],
) -> None:
# TODO: How can this use a connection pool instead ?
self.engine = conn.engine.connect()

def _table(self, table_id: List[str]) -> sa.Table:
return sa.Table(table_id[-1], sa.MetaData(), schema=table_id[-2])

def get_data_for_column(
self, table_id: List[str], column_name: str, sample_size: int = 100
) -> list:
"""
Fetches non-null column values, upto <sample_size> count
Args:
table_id: Table name identifier. One of
- [<db_name>, <schema_name>, <table_name>] or
- [<schema_name>, <table_name>] or
- [<table_name>]
column: Column name
Returns:
list of column values
"""

table = self._table(table_id)
query: Any
ignore_null_condition = sa.column(column_name).is_(None)
# limit doesn't compile properly for oracle so we will append rownum to query string later
if self.engine.dialect.name.lower() == "oracle":
raw_query = (
sa.select([sa.column(column_name)])
.select_from(table)
.where(sa.not_(ignore_null_condition))
)

query = str(
raw_query.compile(self.engine, compile_kwargs={"literal_binds": True})
)
query += "\nAND ROWNUM <= %d" % sample_size
else:
query = (
sa.select([sa.column(column_name)])
.select_from(table)
.where(sa.not_(ignore_null_condition))
.limit(sample_size)
)
query_results = self.engine.execute(query)

return [x[column_name] for x in query_results.fetchall()]

def get_data_for_table(
self, table_id: List[str], sample_size: int = 100
) -> Dict[str, list]:
"""
Fetches table values, upto <sample_size>*1.2 count
Args:
table_id: Table name identifier. One of
- [<db_name>, <schema_name>, <table_name>] or
- [<schema_name>, <table_name>] or
- [<table_name>]
Returns:
dictionary of (column name -> list of column values)
"""
column_values: Dict[str, list] = defaultdict(list)
table = self._table(table_id)

# Ideally we do not want null values in sample data for a column.
# However that would require separate query per column and
# that would be expensiv. To compensate for possibility
# of some null values in collected sample, we fetch extra (20% more)
# rows than configured sample_size.
sample_size = int(sample_size * 1.2)

query: Any

# limit doesn't compile properly for oracle so we will append rownum to query string later
if self.engine.dialect.name.lower() == "oracle":
raw_query = sa.select([sa.text("*")]).select_from(table)

query = str(
raw_query.compile(self.engine, compile_kwargs={"literal_binds": True})
)
query += "\nAND ROWNUM <= %d" % sample_size
else:
query = sa.select([sa.text("*")]).select_from(table).limit(sample_size)
query_results = self.engine.execute(query)
# Not ideal - creates a parallel structure. Can we use pandas here ?
for row in query_results.fetchall():
if isinstance(row, LegacyRow):
for col, col_value in row.items():
column_values[col].append(col_value)
else:
import pdb

pdb.set_trace()

return column_values

def close(self) -> None:
if hasattr(self, "engine"):
self.engine.close()
62 changes: 60 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,18 @@
TestConnectionReport,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.glossary.classification_mixin import (
ClassificationHandler,
ClassificationReportMixin,
)
from datahub.ingestion.source.common.subtypes import (
DatasetContainerSubTypes,
DatasetSubTypes,
)
from datahub.ingestion.source.sql.data_reader import (
DataReader,
GenericSqlTableDataReader,
)
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig
from datahub.ingestion.source.sql.sql_utils import (
add_table_to_schema_container,
Expand Down Expand Up @@ -120,7 +128,7 @@


@dataclass
class SQLSourceReport(StaleEntityRemovalSourceReport):
class SQLSourceReport(StaleEntityRemovalSourceReport, ClassificationReportMixin):
tables_scanned: int = 0
views_scanned: int = 0
entities_profiled: int = 0
Expand Down Expand Up @@ -314,6 +322,7 @@ def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str)
self.report: SQLSourceReport = SQLSourceReport()
self.profile_metadata_info: ProfileMetadata = ProfileMetadata()

self.classification_handler = ClassificationHandler(self.config, self.report)
config_report = {
config_option: config.dict().get(config_option)
for config_option in config_options_to_report
Expand Down Expand Up @@ -643,13 +652,28 @@ def get_foreign_key_metadata(
fk_dict["name"], foreign_fields, source_fields, foreign_dataset
)

def init_data_reader(self, inspector: Inspector) -> Optional[DataReader]:
"""
Subclasses can override this with source-specific data reader
if source provides clause to pick random sample instead of current
limit-based sample
"""
if (
self.classification_handler
and self.classification_handler.is_classification_enabled()
):
return GenericSqlTableDataReader.create(inspector)

return None

def loop_tables( # noqa: C901
self,
inspector: Inspector,
schema: str,
sql_config: SQLCommonConfig,
) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]:
tables_seen: Set[str] = set()
data_reader = self.init_data_reader(inspector)
try:
for table in inspector.get_table_names(schema):
dataset_name = self.get_identifier(
Expand All @@ -669,12 +693,14 @@ def loop_tables( # noqa: C901

try:
yield from self._process_table(
dataset_name, inspector, schema, table, sql_config
dataset_name, inspector, schema, table, sql_config, data_reader
)
except Exception as e:
self.warn(logger, f"{schema}.{table}", f"Ingestion error: {e}")
except Exception as e:
self.error(logger, f"{schema}", f"Tables error: {e}")
if data_reader:
data_reader.close()

def add_information_for_schema(self, inspector: Inspector, schema: str) -> None:
pass
Expand All @@ -691,6 +717,7 @@ def _process_table(
schema: str,
table: str,
sql_config: SQLCommonConfig,
data_reader: Optional[DataReader],
) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]:
columns = self._get_columns(dataset_name, inspector, schema, table)
dataset_urn = make_dataset_urn_with_platform_instance(
Expand Down Expand Up @@ -740,6 +767,8 @@ def _process_table(
foreign_keys,
schema_fields,
)
self._classify(dataset_name, schema, table, data_reader, schema_metadata)

dataset_snapshot.aspects.append(schema_metadata)
if self.config.include_view_lineage:
self.schema_resolver.add_schema_metadata(dataset_urn, schema_metadata)
Expand Down Expand Up @@ -770,6 +799,35 @@ def _process_table(
domain_registry=self.domain_registry,
)

def _classify(self, dataset_name, schema, table, data_reader, schema_metadata):
try:
if (
self.classification_handler.is_classification_enabled_for_table(
dataset_name
)
and data_reader
):
self.classification_handler.classify_schema_fields(
dataset_name,
schema_metadata,
data_reader.get_data_for_table(
table_id=[schema, table],
sample_size=self.config.classification.sample_size,
),
)
except Exception as e:
import pdb

pdb.set_trace()
logger.debug(
f"Failed to classify table columns for {dataset_name} due to error -> {e}",
exc_info=e,
)
self.report.report_warning(
"Failed to classify table columns",
dataset_name,
)

def get_database_properties(
self, inspector: Inspector, database: str
) -> Optional[Dict[str, str]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
LowerCaseDatasetUrnConfigMixin,
)
from datahub.configuration.validate_field_removal import pydantic_removed_field
from datahub.ingestion.glossary.classification_mixin import (
ClassificationSourceConfigMixin,
)
from datahub.ingestion.source.ge_profiling_config import GEProfilingConfig
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StatefulStaleMetadataRemovalConfig,
Expand All @@ -29,6 +32,7 @@ class SQLCommonConfig(
DatasetSourceConfigMixin,
LowerCaseDatasetUrnConfigMixin,
LineageConfig,
ClassificationSourceConfigMixin,
):
options: dict = pydantic.Field(
default_factory=dict,
Expand Down
4 changes: 3 additions & 1 deletion metadata-ingestion/src/datahub/ingestion/source/sql/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.extractor import schema_util
from datahub.ingestion.source.sql.data_reader import DataReader
from datahub.ingestion.source.sql.sql_common import (
SQLAlchemySource,
SqlWorkUnit,
Expand Down Expand Up @@ -334,9 +335,10 @@ def _process_table(
schema: str,
table: str,
sql_config: SQLCommonConfig,
data_reader: Optional[DataReader],
) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]:
yield from super()._process_table(
dataset_name, inspector, schema, table, sql_config
dataset_name, inspector, schema, table, sql_config, data_reader
)
if self.config.ingest_lineage_to_connectors:
dataset_urn = make_dataset_urn_with_platform_instance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
support_status,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.sql.data_reader import DataReader
from datahub.ingestion.source.sql.sql_common import (
SQLAlchemySource,
SQLSourceReport,
Expand Down Expand Up @@ -221,6 +222,7 @@ def _process_table(
schema: str,
table: str,
sql_config: SQLCommonConfig,
data_reader: Optional[DataReader],
) -> Iterable[Union[SqlWorkUnit, MetadataWorkUnit]]:
dataset_urn = make_dataset_urn_with_platform_instance(
self.platform,
Expand All @@ -235,7 +237,7 @@ def _process_table(
owner_urn=f"urn:li:corpuser:{table_owner}",
)
yield from super()._process_table(
dataset_name, inspector, schema, table, sql_config
dataset_name, inspector, schema, table, sql_config, data_reader
)

def loop_views(
Expand Down
Loading

0 comments on commit 764ea8a

Please sign in to comment.