diff --git a/metadata-ingestion/src/datahub/entrypoints.py b/metadata-ingestion/src/datahub/entrypoints.py
index 3ee9ecc277596..e0aac99303c1e 100644
--- a/metadata-ingestion/src/datahub/entrypoints.py
+++ b/metadata-ingestion/src/datahub/entrypoints.py
@@ -47,7 +47,6 @@ def datahub(debug: bool) -> None:
         logging.getLogger("datahub").setLevel(logging.INFO)
     # loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
     # print(loggers)
-    # breakpoint()


 @datahub.command()
diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_usage.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_usage.py
index 9fe158055ef09..3d01e23a1c668 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_usage.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_usage.py
@@ -200,8 +200,6 @@ def from_entry(cls, entry: AuditLogEntry) -> "QueryEvent":
         referencedTables = [
             BigQueryTableRef.from_spec_obj(spec) for spec in rawRefTables
         ]
-        # if job['jobConfiguration']['query']['statementType'] != "SCRIPT" and not referencedTables:
-        #     breakpoint()

         queryEvent = QueryEvent(
             timestamp=entry.timestamp,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake.py
index 0090828da46f1..112d6a06edcf1 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake.py
@@ -1,9 +1,13 @@
 import logging
-from typing import Optional
+from typing import Iterable, Optional

 # This import verifies that the dependencies are available.
 import snowflake.sqlalchemy  # noqa: F401
 from snowflake.sqlalchemy import custom_types
+from sqlalchemy import create_engine, inspect
+from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.sql import text
+from sqlalchemy.sql.elements import quoted_name

 from datahub.configuration.common import ConfigModel
@@ -18,6 +22,7 @@
 register_custom_type(custom_types.TIMESTAMP_TZ, TimeTypeClass)
 register_custom_type(custom_types.TIMESTAMP_LTZ, TimeTypeClass)
 register_custom_type(custom_types.TIMESTAMP_NTZ, TimeTypeClass)
+register_custom_type(custom_types.VARIANT)

 logger: logging.Logger = logging.getLogger(__name__)
@@ -53,7 +58,23 @@ def get_sql_alchemy_url(self, database=None):


 class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
-    database: str
+    database_pattern: AllowDenyPattern = AllowDenyPattern(
+        deny=[
+            r"^UTIL_DB$",
+            r"^SNOWFLAKE$",
+            r"^SNOWFLAKE_SAMPLE_DATA$",
+        ]
+    )
+
+    database: str = ".*"  # deprecated
+
+    @pydantic.validator("database")
+    def note_database_opt_deprecation(cls, v, values, **kwargs):
+        logger.warn(
+            "snowflake's `database` option has been deprecated; use database_pattern instead"
+        )
+        values["database_pattern"].allow = f"^{v}$"
+        return None

     def get_sql_alchemy_url(self):
         return super().get_sql_alchemy_url(self.database)
@@ -64,6 +85,8 @@ def get_identifier(self, schema: str, table: str) -> str:


 class SnowflakeSource(SQLAlchemySource):
+    config: SnowflakeConfig
+
     def __init__(self, config, ctx):
         super().__init__(config, ctx, "snowflake")
@@ -71,3 +94,22 @@ def __init__(self, config, ctx):
     def create(cls, config_dict, ctx):
         config = SnowflakeConfig.parse_obj(config_dict)
         return cls(config, ctx)
+
+    def get_inspectors(self) -> Iterable[Inspector]:
+        url = self.config.get_sql_alchemy_url()
+        logger.debug(f"sql_alchemy_url={url}")
+        engine = create_engine(url, **self.config.options)
+
+        for db_row in engine.execute(text("SHOW DATABASES")):
+            with engine.connect() as conn:
+                db = db_row.name
+                if self.config.database_pattern.allowed(db):
+                    # TRICKY: As we iterate through this loop, we modify the value of
+                    # self.config.database so that the get_identifier method can function
+                    # as intended.
+                    self.config.database = db
+                    conn.execute((f'USE DATABASE "{quoted_name(db, True)}"'))
+                    inspector = inspect(conn)
+                    yield inspector
+                else:
+                    self.report.report_dropped(db)
diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql_common.py
index 48632dcbdbf6c..823bb72407e49 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql_common.py
@@ -238,26 +238,35 @@ def __init__(self, config: SQLAlchemyConfig, ctx: PipelineContext, platform: str
         self.platform = platform
         self.report = SQLSourceReport()

+    def get_inspectors(self) -> Iterable[Inspector]:
+        # This method can be overridden in the case that you want to dynamically
+        # run on multiple databases.
+
+        url = self.config.get_sql_alchemy_url()
+        logger.debug(f"sql_alchemy_url={url}")
+        engine = create_engine(url, **self.config.options)
+        inspector = inspect(engine)
+        yield inspector
+
     def get_workunits(self) -> Iterable[SqlWorkUnit]:
         sql_config = self.config
         if logger.isEnabledFor(logging.DEBUG):
             # If debug logging is enabled, we also want to echo each SQL query issued.
             sql_config.options["echo"] = True

-        url = sql_config.get_sql_alchemy_url()
-        logger.debug(f"sql_alchemy_url={url}")
-        engine = create_engine(url, **sql_config.options)
-        inspector = inspect(engine)
-        for schema in inspector.get_schema_names():
-            if not sql_config.schema_pattern.allowed(schema):
-                self.report.report_dropped(schema)
-                continue
+        for inspector in self.get_inspectors():
+            for schema in inspector.get_schema_names():
+                if not sql_config.schema_pattern.allowed(schema):
+                    self.report.report_dropped(
+                        ".".join(sql_config.standardize_schema_table_names(schema, "*"))
+                    )
+                    continue

-            if sql_config.include_tables:
-                yield from self.loop_tables(inspector, schema, sql_config)
+                if sql_config.include_tables:
+                    yield from self.loop_tables(inspector, schema, sql_config)

-            if sql_config.include_views:
-                yield from self.loop_views(inspector, schema, sql_config)
+                if sql_config.include_views:
+                    yield from self.loop_views(inspector, schema, sql_config)

     def loop_tables(
         self,
diff --git a/metadata-ingestion/tests/unit/test_snowflake_source.py b/metadata-ingestion/tests/unit/test_snowflake_source.py
index 7ed1c1e9bd9a0..47e1a6fef8996 100644
--- a/metadata-ingestion/tests/unit/test_snowflake_source.py
+++ b/metadata-ingestion/tests/unit/test_snowflake_source.py
@@ -15,5 +15,5 @@ def test_snowflake_uri():
     assert (
         config.get_sql_alchemy_url()
-        == "snowflake://user:password@acctname/demo?warehouse=COMPUTE_WH&role=sysadmin"
+        == "snowflake://user:password@acctname/?warehouse=COMPUTE_WH&role=sysadmin"
     )
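
Note on the default database filtering added above: `database_pattern` is applied to each name returned by `SHOW DATABASES` before an inspector is yielded, and any database that does not pass is recorded via `report_dropped`. The sketch below illustrates the allow/deny regex check this relies on; it is an illustration only, not DataHub's `AllowDenyPattern` implementation, and the `allowed` helper plus the sample database names are made up for the example.

import re
from typing import List


def allowed(name: str, allow: List[str], deny: List[str]) -> bool:
    # Deny patterns take precedence; otherwise the name must match an allow pattern.
    if any(re.match(pattern, name) for pattern in deny):
        return False
    return any(re.match(pattern, name) for pattern in allow)


# Default deny list from SnowflakeConfig; the allow side defaults to "match anything".
deny = [r"^UTIL_DB$", r"^SNOWFLAKE$", r"^SNOWFLAKE_SAMPLE_DATA$"]
for db in ["ANALYTICS", "UTIL_DB", "SNOWFLAKE", "SNOWFLAKE_SAMPLE_DATA"]:
    print(db, allowed(db, allow=[r".*"], deny=deny))
# Only ANALYTICS passes; the system and sample databases would be reported as dropped.

Whether the real pattern class gives deny precedence in exactly this way is an assumption; the observable effect of the new defaults is that Snowflake's built-in and sample databases stay out of ingestion unless the pattern is overridden.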