diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py index a9685b2554553..069c1f2781460 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py @@ -14,7 +14,12 @@ platform_name, support_status, ) -from datahub.ingestion.api.source import SourceCapability +from datahub.ingestion.api.source import ( + CapabilityReport, + SourceCapability, + TestableSource, + TestConnectionReport, +) from datahub.ingestion.source.dbt.dbt_common import ( DBTColumn, DBTCommonConfig, @@ -177,7 +182,7 @@ class DBTCloudConfig(DBTCommonConfig): @support_status(SupportStatus.INCUBATING) @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion") @capability(SourceCapability.LINEAGE_COARSE, "Enabled by default") -class DBTCloudSource(DBTSourceBase): +class DBTCloudSource(DBTSourceBase, TestableSource): """ This source pulls dbt metadata directly from the dbt Cloud APIs. @@ -199,6 +204,57 @@ def create(cls, config_dict, ctx): config = DBTCloudConfig.parse_obj(config_dict) return cls(config, ctx, "dbt") + @staticmethod + def test_connection(config_dict: dict) -> TestConnectionReport: + test_report = TestConnectionReport() + try: + source_config = DBTCloudConfig.parse_obj_allow_extras(config_dict) + DBTCloudSource._send_graphql_query( + metadata_endpoint=source_config.metadata_endpoint, + token=source_config.token, + query=_DBT_GRAPHQL_QUERY.format(type="tests", fields="jobId"), + variables={ + "jobId": source_config.job_id, + "runId": source_config.run_id, + }, + ) + test_report.basic_connectivity = CapabilityReport(capable=True) + except Exception as e: + test_report.basic_connectivity = CapabilityReport( + capable=False, failure_reason=str(e) + ) + return test_report + + @staticmethod + def _send_graphql_query( + metadata_endpoint: str, token: str, query: str, variables: Dict + ) -> Dict: + logger.debug(f"Sending GraphQL query to dbt Cloud: {query}") + response = requests.post( + metadata_endpoint, + json={ + "query": query, + "variables": variables, + }, + headers={ + "Authorization": f"Bearer {token}", + "X-dbt-partner-source": "acryldatahub", + }, + ) + + try: + res = response.json() + if "errors" in res: + raise ValueError( + f'Unable to fetch metadata from dbt Cloud: {res["errors"]}' + ) + data = res["data"] + except JSONDecodeError as e: + response.raise_for_status() + raise e + + return data + def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]: # TODO: In dbt Cloud, commands are scheduled as part of jobs, where # each job can have multiple runs. 
We currently only fully support @@ -213,6 +269,8 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]: for node_type, fields in _DBT_FIELDS_BY_TYPE.items(): logger.info(f"Fetching {node_type} from dbt Cloud") data = self._send_graphql_query( + metadata_endpoint=self.config.metadata_endpoint, + token=self.config.token, query=_DBT_GRAPHQL_QUERY.format(type=node_type, fields=fields), variables={ "jobId": self.config.job_id, @@ -232,33 +290,6 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]: return nodes, additional_metadata - def _send_graphql_query(self, query: str, variables: Dict) -> Dict: - logger.debug(f"Sending GraphQL query to dbt Cloud: {query}") - response = requests.post( - self.config.metadata_endpoint, - json={ - "query": query, - "variables": variables, - }, - headers={ - "Authorization": f"Bearer {self.config.token}", - "X-dbt-partner-source": "acryldatahub", - }, - ) - - try: - res = response.json() - if "errors" in res: - raise ValueError( - f'Unable to fetch metadata from dbt Cloud: {res["errors"]}' - ) - data = res["data"] - except JSONDecodeError as e: - response.raise_for_status() - raise e - - return data - def _parse_into_dbt_node(self, node: Dict) -> DBTNode: key = node["uniqueId"] diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py index ac2b2815f3caa..563b005d7a88d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py @@ -18,7 +18,12 @@ platform_name, support_status, ) -from datahub.ingestion.api.source import SourceCapability +from datahub.ingestion.api.source import ( + CapabilityReport, + SourceCapability, + TestableSource, + TestConnectionReport, +) from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig from datahub.ingestion.source.dbt.dbt_common import ( DBTColumn, @@ -60,11 +65,6 @@ class DBTCoreConfig(DBTCommonConfig): _github_info_deprecated = pydantic_renamed_field("github_info", "git_info") - @property - def s3_client(self): - assert self.aws_connection - return self.aws_connection.get_s3_client() - @validator("aws_connection") def aws_connection_needed_if_s3_uris_present( cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any @@ -363,7 +363,7 @@ def load_test_results( @support_status(SupportStatus.CERTIFIED) @capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion") @capability(SourceCapability.LINEAGE_COARSE, "Enabled by default") -class DBTCoreSource(DBTSourceBase): +class DBTCoreSource(DBTSourceBase, TestableSource): """ The artifacts used by this source are: - [dbt manifest file](https://docs.getdbt.com/reference/artifacts/manifest-json) @@ -387,12 +387,34 @@ def create(cls, config_dict, ctx): config = DBTCoreConfig.parse_obj(config_dict) return cls(config, ctx, "dbt") - def load_file_as_json(self, uri: str) -> Any: + @staticmethod + def test_connection(config_dict: dict) -> TestConnectionReport: + test_report = TestConnectionReport() + try: + source_config = DBTCoreConfig.parse_obj_allow_extras(config_dict) + DBTCoreSource.load_file_as_json( + source_config.manifest_path, source_config.aws_connection + ) + DBTCoreSource.load_file_as_json( + source_config.catalog_path, source_config.aws_connection + ) + test_report.basic_connectivity = CapabilityReport(capable=True) + except Exception as e: + test_report.basic_connectivity = CapabilityReport( + capable=False, 
failure_reason=str(e) + ) + return test_report + + @staticmethod + def load_file_as_json( + uri: str, aws_connection: Optional[AwsConnectionConfig] + ) -> Dict: if re.match("^https?://", uri): return json.loads(requests.get(uri).text) elif re.match("^s3://", uri): u = urlparse(uri) - response = self.config.s3_client.get_object( + assert aws_connection + response = aws_connection.get_s3_client().get_object( Bucket=u.netloc, Key=u.path.lstrip("/") ) return json.loads(response["Body"].read().decode("utf-8")) @@ -410,12 +432,18 @@ def loadManifestAndCatalog( Optional[str], Optional[str], ]: - dbt_manifest_json = self.load_file_as_json(self.config.manifest_path) + dbt_manifest_json = self.load_file_as_json( + self.config.manifest_path, self.config.aws_connection + ) - dbt_catalog_json = self.load_file_as_json(self.config.catalog_path) + dbt_catalog_json = self.load_file_as_json( + self.config.catalog_path, self.config.aws_connection + ) if self.config.sources_path is not None: - dbt_sources_json = self.load_file_as_json(self.config.sources_path) + dbt_sources_json = self.load_file_as_json( + self.config.sources_path, self.config.aws_connection + ) sources_results = dbt_sources_json["results"] else: sources_results = {} @@ -491,7 +519,9 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]: # This will populate the test_results field on each test node. all_nodes = load_test_results( self.config, - self.load_file_as_json(self.config.test_results_path), + self.load_file_as_json( + self.config.test_results_path, self.config.aws_connection + ), all_nodes, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka.py b/metadata-ingestion/src/datahub/ingestion/source/kafka.py index 25520e7aa66ff..99ef737206ab0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka.py +++ b/metadata-ingestion/src/datahub/ingestion/source/kafka.py @@ -15,6 +15,7 @@ ConfigResource, TopicMetadata, ) +from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient from datahub.configuration.common import AllowDenyPattern from datahub.configuration.kafka import KafkaConsumerConnectionConfig @@ -40,7 +41,13 @@ support_status, ) from datahub.ingestion.api.registry import import_path -from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceCapability +from datahub.ingestion.api.source import ( + CapabilityReport, + MetadataWorkUnitProcessor, + SourceCapability, + TestableSource, + TestConnectionReport, +) from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.common.subtypes import DatasetSubTypes from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase @@ -133,6 +140,18 @@ class KafkaSourceConfig( ) +def get_kafka_consumer( + connection: KafkaConsumerConnectionConfig, +) -> confluent_kafka.Consumer: + return confluent_kafka.Consumer( + { + "group.id": "test", + "bootstrap.servers": connection.bootstrap, + **connection.consumer_config, + } + ) + + @dataclass class KafkaSourceReport(StaleEntityRemovalSourceReport): topics_scanned: int = 0 @@ -145,6 +164,45 @@ def report_dropped(self, topic: str) -> None: self.filtered.append(topic) +class KafkaConnectionTest: + def __init__(self, config_dict: dict): + self.config = KafkaSourceConfig.parse_obj_allow_extras(config_dict) + self.report = KafkaSourceReport() + self.consumer: confluent_kafka.Consumer = get_kafka_consumer( + self.config.connection + ) + + def get_connection_test(self) -> TestConnectionReport: + capability_report = { + 
SourceCapability.SCHEMA_METADATA: self.schema_registry_connectivity(), + } + return TestConnectionReport( + basic_connectivity=self.basic_connectivity(), + capability_report={ + k: v for k, v in capability_report.items() if v is not None + }, + ) + + def basic_connectivity(self) -> CapabilityReport: + try: + self.consumer.list_topics(timeout=10) + return CapabilityReport(capable=True) + except Exception as e: + return CapabilityReport(capable=False, failure_reason=str(e)) + + def schema_registry_connectivity(self) -> CapabilityReport: + try: + SchemaRegistryClient( + { + "url": self.config.connection.schema_registry_url, + **self.config.connection.schema_registry_config, + } + ).get_subjects() + return CapabilityReport(capable=True) + except Exception as e: + return CapabilityReport(capable=False, failure_reason=str(e)) + + @platform_name("Kafka") @config_class(KafkaSourceConfig) @support_status(SupportStatus.CERTIFIED) @@ -160,7 +218,7 @@ def report_dropped(self, topic: str) -> None: SourceCapability.SCHEMA_METADATA, "Schemas associated with each topic are extracted from the schema registry. Avro and Protobuf (certified), JSON (incubating). Schema references are supported.", ) -class KafkaSource(StatefulIngestionSourceBase): +class KafkaSource(StatefulIngestionSourceBase, TestableSource): """ This plugin extracts the following: - Topics from the Kafka broker @@ -183,12 +241,8 @@ def create_schema_registry( def __init__(self, config: KafkaSourceConfig, ctx: PipelineContext): super().__init__(config, ctx) self.source_config: KafkaSourceConfig = config - self.consumer: confluent_kafka.Consumer = confluent_kafka.Consumer( - { - "group.id": "test", - "bootstrap.servers": self.source_config.connection.bootstrap, - **self.source_config.connection.consumer_config, - } + self.consumer: confluent_kafka.Consumer = get_kafka_consumer( + self.source_config.connection ) self.init_kafka_admin_client() self.report: KafkaSourceReport = KafkaSourceReport() @@ -226,6 +280,10 @@ def init_kafka_admin_client(self) -> None: f"Failed to create Kafka Admin Client due to error {e}.", ) + @staticmethod + def test_connection(config_dict: dict) -> TestConnectionReport: + return KafkaConnectionTest(config_dict).get_connection_test() + @classmethod def create(cls, config_dict: Dict, ctx: PipelineContext) -> "KafkaSource": config: KafkaSourceConfig = KafkaSourceConfig.parse_obj(config_dict) diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py index 4b1d0403ac776..cdf7c975c0614 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py @@ -19,7 +19,13 @@ platform_name, support_status, ) -from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceReport +from datahub.ingestion.api.source import ( + CapabilityReport, + MetadataWorkUnitProcessor, + SourceReport, + TestableSource, + TestConnectionReport, +) from datahub.ingestion.api.source_helpers import auto_workunit from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.common.subtypes import ( @@ -1147,7 +1153,7 @@ def report_to_datahub_work_units( SourceCapability.LINEAGE_FINE, "Disabled by default, configured using `extract_column_level_lineage`. 
", ) -class PowerBiDashboardSource(StatefulIngestionSourceBase): +class PowerBiDashboardSource(StatefulIngestionSourceBase, TestableSource): """ This plugin extracts the following: - Power BI dashboards, tiles and datasets @@ -1186,6 +1192,18 @@ def __init__(self, config: PowerBiDashboardSourceConfig, ctx: PipelineContext): self, self.source_config, self.ctx ) + @staticmethod + def test_connection(config_dict: dict) -> TestConnectionReport: + test_report = TestConnectionReport() + try: + PowerBiAPI(PowerBiDashboardSourceConfig.parse_obj_allow_extras(config_dict)) + test_report.basic_connectivity = CapabilityReport(capable=True) + except Exception as e: + test_report.basic_connectivity = CapabilityReport( + capable=False, failure_reason=str(e) + ) + return test_report + @classmethod def create(cls, config_dict, ctx): config = PowerBiDashboardSourceConfig.parse_obj(config_dict) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py index 590bc7f696784..a831dfa50342d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py @@ -15,6 +15,7 @@ Tuple, Type, Union, + cast, ) import sqlalchemy.dialects.postgresql.base @@ -35,7 +36,12 @@ from datahub.emitter.sql_parsing_builder import SqlParsingBuilder from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.incremental_lineage_helper import auto_incremental_lineage -from datahub.ingestion.api.source import MetadataWorkUnitProcessor +from datahub.ingestion.api.source import ( + CapabilityReport, + MetadataWorkUnitProcessor, + TestableSource, + TestConnectionReport, +) from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.common.subtypes import ( DatasetContainerSubTypes, @@ -298,7 +304,7 @@ class ProfileMetadata: dataset_name_to_storage_bytes: Dict[str, int] = field(default_factory=dict) -class SQLAlchemySource(StatefulIngestionSourceBase): +class SQLAlchemySource(StatefulIngestionSourceBase, TestableSource): """A Base class for all SQL Sources that use SQLAlchemy to extend""" def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str): @@ -348,6 +354,22 @@ def __init__(self, config: SQLCommonConfig, ctx: PipelineContext, platform: str) else: self._view_definition_cache = {} + @classmethod + def test_connection(cls, config_dict: dict) -> TestConnectionReport: + test_report = TestConnectionReport() + try: + source = cast( + SQLAlchemySource, + cls.create(config_dict, PipelineContext(run_id="test_connection")), + ) + list(source.get_inspectors()) + test_report.basic_connectivity = CapabilityReport(capable=True) + except Exception as e: + test_report.basic_connectivity = CapabilityReport( + capable=False, failure_reason=str(e) + ) + return test_report + def warn(self, log: logging.Logger, key: str, reason: str) -> None: self.report.report_warning(key, reason[:100]) log.warning(f"{key} => {reason}") diff --git a/metadata-ingestion/src/datahub/ingestion/source/tableau.py b/metadata-ingestion/src/datahub/ingestion/source/tableau.py index da44d09121c6c..89fb351317b6d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/tableau.py +++ b/metadata-ingestion/src/datahub/ingestion/source/tableau.py @@ -57,7 +57,13 @@ platform_name, support_status, ) -from datahub.ingestion.api.source import MetadataWorkUnitProcessor, Source +from datahub.ingestion.api.source import ( + CapabilityReport, + 
MetadataWorkUnitProcessor, + Source, + TestableSource, + TestConnectionReport, +) from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source import tableau_constant as c from datahub.ingestion.source.common.subtypes import ( @@ -456,7 +462,7 @@ class TableauSourceReport(StaleEntityRemovalSourceReport): SourceCapability.LINEAGE_FINE, "Enabled by default, configure using `extract_column_level_lineage`", ) -class TableauSource(StatefulIngestionSourceBase): +class TableauSource(StatefulIngestionSourceBase, TestableSource): platform = "tableau" def __hash__(self): @@ -496,6 +502,19 @@ def __init__( self._authenticate() + @staticmethod + def test_connection(config_dict: dict) -> TestConnectionReport: + test_report = TestConnectionReport() + try: + source_config = TableauConfig.parse_obj_allow_extras(config_dict) + source_config.make_tableau_client() + test_report.basic_connectivity = CapabilityReport(capable=True) + except Exception as e: + test_report.basic_connectivity = CapabilityReport( + capable=False, failure_reason=str(e) + ) + return test_report + def close(self) -> None: try: if self.server is not None: diff --git a/metadata-ingestion/src/datahub/ingestion/source_config/sql/snowflake.py b/metadata-ingestion/src/datahub/ingestion/source_config/sql/snowflake.py index ccc4e115729a2..46bd24c7e1f4c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_config/sql/snowflake.py +++ b/metadata-ingestion/src/datahub/ingestion/source_config/sql/snowflake.py @@ -143,7 +143,7 @@ def _check_oauth_config(oauth_config: Optional[OAuthConfiguration]) -> None: "'oauth_config' is none but should be set when using OAUTH_AUTHENTICATOR authentication" ) if oauth_config.use_certificate is True: - if oauth_config.provider == OAuthIdentityProvider.OKTA.value: + if oauth_config.provider == OAuthIdentityProvider.OKTA: raise ValueError( "Certificate authentication is not supported for Okta." 
) diff --git a/metadata-ingestion/tests/integration/dbt/test_dbt.py b/metadata-ingestion/tests/integration/dbt/test_dbt.py index 95b5374bbb41d..587831495c1ea 100644 --- a/metadata-ingestion/tests/integration/dbt/test_dbt.py +++ b/metadata-ingestion/tests/integration/dbt/test_dbt.py @@ -10,20 +10,25 @@ from datahub.ingestion.run.pipeline import Pipeline from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig from datahub.ingestion.source.dbt.dbt_common import DBTEntitiesEnabled, EmitDirective -from datahub.ingestion.source.dbt.dbt_core import DBTCoreConfig +from datahub.ingestion.source.dbt.dbt_core import DBTCoreConfig, DBTCoreSource from datahub.ingestion.source.sql.sql_types import ( ATHENA_SQL_TYPES_MAP, TRINO_SQL_TYPES_MAP, resolve_athena_modified_type, resolve_trino_modified_type, ) -from tests.test_helpers import mce_helpers +from tests.test_helpers import mce_helpers, test_connection_helpers FROZEN_TIME = "2022-02-03 07:00:00" GMS_PORT = 8080 GMS_SERVER = f"http://localhost:{GMS_PORT}" +@pytest.fixture(scope="module") +def test_resources_dir(pytestconfig): + return pytestconfig.rootpath / "tests/integration/dbt" + + @dataclass class DbtTestConfig: run_id: str @@ -195,7 +200,14 @@ def set_paths( ) @pytest.mark.integration @freeze_time(FROZEN_TIME) -def test_dbt_ingest(dbt_test_config, pytestconfig, tmp_path, mock_time, requests_mock): +def test_dbt_ingest( + dbt_test_config, + test_resources_dir, + pytestconfig, + tmp_path, + mock_time, + requests_mock, +): config: DbtTestConfig = dbt_test_config test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt" @@ -233,11 +245,48 @@ def test_dbt_ingest(dbt_test_config, pytestconfig, tmp_path, mock_time, requests ) +@pytest.mark.parametrize( + "config_dict, is_success", + [ + ( + { + "manifest_path": "dbt_manifest.json", + "catalog_path": "dbt_catalog.json", + "target_platform": "postgres", + }, + True, + ), + ( + { + "manifest_path": "dbt_manifest.json", + "catalog_path": "dbt_catalog-this-file-does-not-exist.json", + "target_platform": "postgres", + }, + False, + ), + ], +) @pytest.mark.integration @freeze_time(FROZEN_TIME) -def test_dbt_tests(pytestconfig, tmp_path, mock_time, **kwargs): - test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt" +def test_dbt_test_connection(test_resources_dir, config_dict, is_success): + config_dict["manifest_path"] = str( + (test_resources_dir / config_dict["manifest_path"]).resolve() + ) + config_dict["catalog_path"] = str( + (test_resources_dir / config_dict["catalog_path"]).resolve() + ) + report = test_connection_helpers.run_test_connection(DBTCoreSource, config_dict) + if is_success: + test_connection_helpers.assert_basic_connectivity_success(report) + else: + test_connection_helpers.assert_basic_connectivity_failure( + report, "No such file or directory" + ) + +@pytest.mark.integration +@freeze_time(FROZEN_TIME) +def test_dbt_tests(test_resources_dir, pytestconfig, tmp_path, mock_time, **kwargs): # Run the metadata ingestion pipeline. 
output_file = tmp_path / "dbt_test_events.json" golden_path = test_resources_dir / "dbt_test_events_golden.json" @@ -340,9 +389,9 @@ def test_resolve_athena_modified_type(data_type, expected_data_type): @pytest.mark.integration @freeze_time(FROZEN_TIME) -def test_dbt_tests_only_assertions(pytestconfig, tmp_path, mock_time, **kwargs): - test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt" - +def test_dbt_tests_only_assertions( + test_resources_dir, pytestconfig, tmp_path, mock_time, **kwargs +): # Run the metadata ingestion pipeline. output_file = tmp_path / "test_only_assertions.json" @@ -418,10 +467,8 @@ def test_dbt_tests_only_assertions(pytestconfig, tmp_path, mock_time, **kwargs): @pytest.mark.integration @freeze_time(FROZEN_TIME) def test_dbt_only_test_definitions_and_results( - pytestconfig, tmp_path, mock_time, **kwargs + test_resources_dir, pytestconfig, tmp_path, mock_time, **kwargs ): - test_resources_dir = pytestconfig.rootpath / "tests/integration/dbt" - # Run the metadata ingestion pipeline. output_file = tmp_path / "test_only_definitions_and_assertions.json" diff --git a/metadata-ingestion/tests/integration/kafka/test_kafka.py b/metadata-ingestion/tests/integration/kafka/test_kafka.py index 63d284801c94c..dfdbea5de5cbf 100644 --- a/metadata-ingestion/tests/integration/kafka/test_kafka.py +++ b/metadata-ingestion/tests/integration/kafka/test_kafka.py @@ -3,18 +3,22 @@ import pytest from freezegun import freeze_time -from tests.test_helpers import mce_helpers +from datahub.ingestion.api.source import SourceCapability +from datahub.ingestion.source.kafka import KafkaSource +from tests.test_helpers import mce_helpers, test_connection_helpers from tests.test_helpers.click_helpers import run_datahub_cmd from tests.test_helpers.docker_helpers import wait_for_port FROZEN_TIME = "2020-04-14 07:00:00" -@freeze_time(FROZEN_TIME) -@pytest.mark.integration -def test_kafka_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time): - test_resources_dir = pytestconfig.rootpath / "tests/integration/kafka" +@pytest.fixture(scope="module") +def test_resources_dir(pytestconfig): + return pytestconfig.rootpath / "tests/integration/kafka" + +@pytest.fixture(scope="module") +def mock_kafka_service(docker_compose_runner, test_resources_dir): with docker_compose_runner( test_resources_dir / "docker-compose.yml", "kafka", cleanup=False ) as docker_services: @@ -31,14 +35,67 @@ def test_kafka_ingest(docker_compose_runner, pytestconfig, tmp_path, mock_time): command = f"{test_resources_dir}/send_records.sh {test_resources_dir}" subprocess.run(command, shell=True, check=True) - # Run the metadata ingestion pipeline. - config_file = (test_resources_dir / "kafka_to_file.yml").resolve() - run_datahub_cmd(["ingest", "-c", f"{config_file}"], tmp_path=tmp_path) + yield docker_compose_runner + + +@freeze_time(FROZEN_TIME) +@pytest.mark.integration +def test_kafka_ingest( + mock_kafka_service, test_resources_dir, pytestconfig, tmp_path, mock_time +): + # Run the metadata ingestion pipeline. + config_file = (test_resources_dir / "kafka_to_file.yml").resolve() + run_datahub_cmd(["ingest", "-c", f"{config_file}"], tmp_path=tmp_path) - # Verify the output. - mce_helpers.check_golden_file( - pytestconfig, - output_path=tmp_path / "kafka_mces.json", - golden_path=test_resources_dir / "kafka_mces_golden.json", - ignore_paths=[], + # Verify the output. 
+ mce_helpers.check_golden_file( + pytestconfig, + output_path=tmp_path / "kafka_mces.json", + golden_path=test_resources_dir / "kafka_mces_golden.json", + ignore_paths=[], + ) + + +@pytest.mark.parametrize( + "config_dict, is_success", + [ + ( + { + "connection": { + "bootstrap": "localhost:29092", + "schema_registry_url": "http://localhost:28081", + }, + }, + True, + ), + ( + { + "connection": { + "bootstrap": "localhost:2909", + "schema_registry_url": "http://localhost:2808", + }, + }, + False, + ), + ], +) +@pytest.mark.integration +@freeze_time(FROZEN_TIME) +def test_kafka_test_connection(mock_kafka_service, config_dict, is_success): + report = test_connection_helpers.run_test_connection(KafkaSource, config_dict) + if is_success: + test_connection_helpers.assert_basic_connectivity_success(report) + test_connection_helpers.assert_capability_report( + capability_report=report.capability_report, + success_capabilities=[SourceCapability.SCHEMA_METADATA], + ) + else: + test_connection_helpers.assert_basic_connectivity_failure( + report, "Failed to get metadata" + ) + test_connection_helpers.assert_capability_report( + capability_report=report.capability_report, + failure_capabilities={ + SourceCapability.SCHEMA_METADATA: "Failed to establish a new connection" + }, ) diff --git a/metadata-ingestion/tests/integration/mysql/test_mysql.py b/metadata-ingestion/tests/integration/mysql/test_mysql.py index 23fd97ff2671e..c19198c7d2bbd 100644 --- a/metadata-ingestion/tests/integration/mysql/test_mysql.py +++ b/metadata-ingestion/tests/integration/mysql/test_mysql.py @@ -3,7 +3,8 @@ import pytest from freezegun import freeze_time -from tests.test_helpers import mce_helpers +from datahub.ingestion.source.sql.mysql import MySQLSource +from tests.test_helpers import mce_helpers, test_connection_helpers from tests.test_helpers.click_helpers import run_datahub_cmd from tests.test_helpers.docker_helpers import wait_for_port @@ -75,3 +76,38 @@ def test_mysql_ingest_no_db( output_path=tmp_path / "mysql_mces.json", golden_path=test_resources_dir / golden_file, ) + + +@pytest.mark.parametrize( + "config_dict, is_success", + [ + ( + { + "host_port": "localhost:53307", + "database": "northwind", + "username": "root", + "password": "example", + }, + True, + ), + ( + { + "host_port": "localhost:5330", + "database": "wrong_db", + "username": "wrong_user", + "password": "wrong_pass", + }, + False, + ), + ], +) +@freeze_time(FROZEN_TIME) +@pytest.mark.integration +def test_mysql_test_connection(mysql_runner, config_dict, is_success): + report = test_connection_helpers.run_test_connection(MySQLSource, config_dict) + if is_success: + test_connection_helpers.assert_basic_connectivity_success(report) + else: + test_connection_helpers.assert_basic_connectivity_failure( + report, "Connection refused" + ) diff --git a/metadata-ingestion/tests/integration/powerbi/test_powerbi.py b/metadata-ingestion/tests/integration/powerbi/test_powerbi.py index c9b0ded433749..698047ae0d1c7 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_powerbi.py +++ b/metadata-ingestion/tests/integration/powerbi/test_powerbi.py @@ -19,7 +19,7 @@ Report, Workspace, ) -from tests.test_helpers import mce_helpers +from tests.test_helpers import mce_helpers, test_connection_helpers pytestmark = pytest.mark.integration_batch_2 FROZEN_TIME = "2022-02-03 07:00:00" @@ -663,6 +663,27 @@ def test_powerbi_ingest(mock_msal, pytestconfig, tmp_path, mock_time, requests_m ) +@freeze_time(FROZEN_TIME) +@mock.patch("msal.ConfidentialClientApplication", 
side_effect=mock_msal_cca) +@pytest.mark.integration +def test_powerbi_test_connection_success(mock_msal): + report = test_connection_helpers.run_test_connection( + PowerBiDashboardSource, default_source_config() + ) + test_connection_helpers.assert_basic_connectivity_success(report) + + +@freeze_time(FROZEN_TIME) +@pytest.mark.integration +def test_powerbi_test_connection_failure(): + report = test_connection_helpers.run_test_connection( + PowerBiDashboardSource, default_source_config() + ) + test_connection_helpers.assert_basic_connectivity_failure( + report, "Unable to get authority configuration" + ) + + @freeze_time(FROZEN_TIME) @mock.patch("msal.ConfidentialClientApplication", side_effect=mock_msal_cca) @pytest.mark.integration diff --git a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py index 0510f4a40f659..90fa71013338d 100644 --- a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py +++ b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py @@ -28,7 +28,7 @@ ) from datahub.metadata.schema_classes import MetadataChangeProposalClass, UpstreamClass from datahub.utilities.sqlglot_lineage import SqlParsingResult -from tests.test_helpers import mce_helpers +from tests.test_helpers import mce_helpers, test_connection_helpers from tests.test_helpers.state_helpers import ( get_current_checkpoint_from_pipeline, validate_all_providers_have_committed_successfully, @@ -290,6 +290,25 @@ def test_tableau_ingest(pytestconfig, tmp_path, mock_datahub_graph): ) +@freeze_time(FROZEN_TIME) +@pytest.mark.integration +def test_tableau_test_connection_success(): + with mock.patch("datahub.ingestion.source.tableau.Server"): + report = test_connection_helpers.run_test_connection( + TableauSource, config_source_default + ) + test_connection_helpers.assert_basic_connectivity_success(report) + + +@freeze_time(FROZEN_TIME) +@pytest.mark.integration +def test_tableau_test_connection_failure(): + report = test_connection_helpers.run_test_connection( + TableauSource, config_source_default + ) + test_connection_helpers.assert_basic_connectivity_failure(report, "Unable to login") + + @freeze_time(FROZEN_TIME) @pytest.mark.integration def test_tableau_cll_ingest(pytestconfig, tmp_path, mock_datahub_graph): diff --git a/metadata-ingestion/tests/test_helpers/test_connection_helpers.py b/metadata-ingestion/tests/test_helpers/test_connection_helpers.py new file mode 100644 index 0000000000000..45543033ae010 --- /dev/null +++ b/metadata-ingestion/tests/test_helpers/test_connection_helpers.py @@ -0,0 +1,47 @@ +from typing import Dict, List, Optional, Type, Union + +from datahub.ingestion.api.source import ( + CapabilityReport, + SourceCapability, + TestableSource, + TestConnectionReport, +) + + +def run_test_connection( + source_cls: Type[TestableSource], config_dict: Dict +) -> TestConnectionReport: + return source_cls.test_connection(config_dict) + + +def assert_basic_connectivity_success(report: TestConnectionReport) -> None: + assert report is not None + assert report.basic_connectivity + assert report.basic_connectivity.capable + assert report.basic_connectivity.failure_reason is None + + +def assert_basic_connectivity_failure( + report: TestConnectionReport, expected_reason: str +) -> None: + assert report is not None + assert report.basic_connectivity + assert not report.basic_connectivity.capable + assert report.basic_connectivity.failure_reason + assert expected_reason in 
report.basic_connectivity.failure_reason + + +def assert_capability_report( + capability_report: Optional[Dict[Union[SourceCapability, str], CapabilityReport]], + success_capabilities: List[SourceCapability] = [], + failure_capabilities: Dict[SourceCapability, str] = {}, +) -> None: + assert capability_report + for capability in success_capabilities: + assert capability_report[capability] + assert capability_report[capability].failure_reason is None + for capability, expected_reason in failure_capabilities.items(): + assert not capability_report[capability].capable + failure_reason = capability_report[capability].failure_reason + assert failure_reason + assert expected_reason in failure_reason diff --git a/metadata-ingestion/tests/unit/test_snowflake_source.py b/metadata-ingestion/tests/unit/test_snowflake_source.py index 343f4466fd6fd..536c91ace4f5e 100644 --- a/metadata-ingestion/tests/unit/test_snowflake_source.py +++ b/metadata-ingestion/tests/unit/test_snowflake_source.py @@ -1,3 +1,4 @@ +from typing import Any, Dict from unittest.mock import MagicMock, patch import pytest @@ -24,10 +25,20 @@ SnowflakeObjectAccessEntry, ) from datahub.ingestion.source.snowflake.snowflake_v2 import SnowflakeV2Source +from tests.test_helpers import test_connection_helpers + +default_oauth_dict: Dict[str, Any] = { + "client_id": "client_id", + "client_secret": "secret", + "use_certificate": False, + "provider": "microsoft", + "scopes": ["datahub_role"], + "authority_url": "https://dev-abc.okta.com/oauth2/def/v1/token", +} def test_snowflake_source_throws_error_on_account_id_missing(): - with pytest.raises(ValidationError): + with pytest.raises(ValidationError, match="account_id\n field required"): SnowflakeV2Config.parse_obj( { "username": "user", @@ -37,27 +48,21 @@ def test_snowflake_source_throws_error_on_account_id_missing(): def test_no_client_id_invalid_oauth_config(): - oauth_dict = { - "provider": "microsoft", - "scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"], - "client_secret": "6Hb9apkbc6HD7", - "authority_url": "https://login.microsoftonline.com/yourorganisation.com", - } - with pytest.raises(ValueError): + oauth_dict = default_oauth_dict.copy() + del oauth_dict["client_id"] + with pytest.raises(ValueError, match="client_id\n field required"): OAuthConfiguration.parse_obj(oauth_dict) def test_snowflake_throws_error_on_client_secret_missing_if_use_certificate_is_false(): - oauth_dict = { - "client_id": "882e9831-7ea51cb2b954", - "provider": "microsoft", - "scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"], - "use_certificate": False, - "authority_url": "https://login.microsoftonline.com/yourorganisation.com", - } + oauth_dict = default_oauth_dict.copy() + del oauth_dict["client_secret"] OAuthConfiguration.parse_obj(oauth_dict) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="'oauth_config.client_secret' was none but should be set when using use_certificate false for oauth_config", + ): SnowflakeV2Config.parse_obj( { "account_id": "test", @@ -68,16 +73,13 @@ def test_snowflake_throws_error_on_client_secret_missing_if_use_certificate_is_f def test_snowflake_throws_error_on_encoded_oauth_private_key_missing_if_use_certificate_is_true(): - oauth_dict = { - "client_id": "882e9831-7ea51cb2b954", - "provider": "microsoft", - "scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"], - "use_certificate": True, - "authority_url": "https://login.microsoftonline.com/yourorganisation.com", - "encoded_oauth_public_key": 
"fkdsfhkshfkjsdfiuwrwfkjhsfskfhksjf==", - } + oauth_dict = default_oauth_dict.copy() + oauth_dict["use_certificate"] = True OAuthConfiguration.parse_obj(oauth_dict) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="'base64_encoded_oauth_private_key' was none but should be set when using certificate for oauth_config", + ): SnowflakeV2Config.parse_obj( { "account_id": "test", @@ -88,16 +90,13 @@ def test_snowflake_throws_error_on_encoded_oauth_private_key_missing_if_use_cert def test_snowflake_oauth_okta_does_not_support_certificate(): - oauth_dict = { - "client_id": "882e9831-7ea51cb2b954", - "provider": "okta", - "scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"], - "use_certificate": True, - "authority_url": "https://login.microsoftonline.com/yourorganisation.com", - "encoded_oauth_public_key": "fkdsfhkshfkjsdfiuwrwfkjhsfskfhksjf==", - } + oauth_dict = default_oauth_dict.copy() + oauth_dict["use_certificate"] = True + oauth_dict["provider"] = "okta" OAuthConfiguration.parse_obj(oauth_dict) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Certificate authentication is not supported for Okta." + ): SnowflakeV2Config.parse_obj( { "account_id": "test", @@ -108,79 +107,52 @@ def test_snowflake_oauth_okta_does_not_support_certificate(): def test_snowflake_oauth_happy_paths(): - okta_dict = { - "client_id": "client_id", - "client_secret": "secret", - "provider": "okta", - "scopes": ["datahub_role"], - "authority_url": "https://dev-abc.okta.com/oauth2/def/v1/token", - } + oauth_dict = default_oauth_dict.copy() + oauth_dict["provider"] = "okta" assert SnowflakeV2Config.parse_obj( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", - "oauth_config": okta_dict, + "oauth_config": oauth_dict, } ) - - microsoft_dict = { - "client_id": "client_id", - "provider": "microsoft", - "scopes": ["https://microsoft.com/f4b353d5-ef8d/.default"], - "use_certificate": True, - "authority_url": "https://login.microsoftonline.com/yourorganisation.com", - "encoded_oauth_public_key": "publickey", - "encoded_oauth_private_key": "privatekey", - } + oauth_dict["use_certificate"] = True + oauth_dict["provider"] = "microsoft" + oauth_dict["encoded_oauth_public_key"] = "publickey" + oauth_dict["encoded_oauth_private_key"] = "privatekey" assert SnowflakeV2Config.parse_obj( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", - "oauth_config": microsoft_dict, + "oauth_config": oauth_dict, } ) +default_config_dict: Dict[str, Any] = { + "username": "user", + "password": "password", + "account_id": "https://acctname.snowflakecomputing.com", + "warehouse": "COMPUTE_WH", + "role": "sysadmin", +} + + def test_account_id_is_added_when_host_port_is_present(): - config = SnowflakeV2Config.parse_obj( - { - "username": "user", - "password": "password", - "host_port": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - ) + config_dict = default_config_dict.copy() + del config_dict["account_id"] + config_dict["host_port"] = "acctname" + config = SnowflakeV2Config.parse_obj(config_dict) assert config.account_id == "acctname" def test_account_id_with_snowflake_host_suffix(): - config = SnowflakeV2Config.parse_obj( - { - "username": "user", - "password": "password", - "account_id": "https://acctname.snowflakecomputing.com", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - ) + config = SnowflakeV2Config.parse_obj(default_config_dict) 
assert config.account_id == "acctname" def test_snowflake_uri_default_authentication(): - config = SnowflakeV2Config.parse_obj( - { - "username": "user", - "password": "password", - "account_id": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - ) - + config = SnowflakeV2Config.parse_obj(default_config_dict) assert config.get_sql_alchemy_url() == ( "snowflake://user:password@acctname" "?application=acryl_datahub" @@ -191,17 +163,10 @@ def test_snowflake_uri_default_authentication(): def test_snowflake_uri_external_browser_authentication(): - config = SnowflakeV2Config.parse_obj( - { - "username": "user", - "account_id": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - "authentication_type": "EXTERNAL_BROWSER_AUTHENTICATOR", - } - ) - + config_dict = default_config_dict.copy() + del config_dict["password"] + config_dict["authentication_type"] = "EXTERNAL_BROWSER_AUTHENTICATOR" + config = SnowflakeV2Config.parse_obj(config_dict) assert config.get_sql_alchemy_url() == ( "snowflake://user@acctname" "?application=acryl_datahub" @@ -212,18 +177,12 @@ def test_snowflake_uri_external_browser_authentication(): def test_snowflake_uri_key_pair_authentication(): - config = SnowflakeV2Config.parse_obj( - { - "username": "user", - "account_id": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - "authentication_type": "KEY_PAIR_AUTHENTICATOR", - "private_key_path": "/a/random/path", - "private_key_password": "a_random_password", - } - ) + config_dict = default_config_dict.copy() + del config_dict["password"] + config_dict["authentication_type"] = "KEY_PAIR_AUTHENTICATOR" + config_dict["private_key_path"] = "/a/random/path" + config_dict["private_key_password"] = "a_random_password" + config = SnowflakeV2Config.parse_obj(config_dict) assert config.get_sql_alchemy_url() == ( "snowflake://user@acctname" @@ -235,63 +194,35 @@ def test_snowflake_uri_key_pair_authentication(): def test_options_contain_connect_args(): - config = SnowflakeV2Config.parse_obj( - { - "username": "user", - "password": "password", - "account_id": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - ) + config = SnowflakeV2Config.parse_obj(default_config_dict) connect_args = config.get_options().get("connect_args") assert connect_args is not None def test_snowflake_config_with_view_lineage_no_table_lineage_throws_error(): - with pytest.raises(ValidationError): - SnowflakeV2Config.parse_obj( - { - "username": "user", - "password": "password", - "account_id": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - "include_view_lineage": True, - "include_table_lineage": False, - } - ) + config_dict = default_config_dict.copy() + config_dict["include_view_lineage"] = True + config_dict["include_table_lineage"] = False + with pytest.raises( + ValidationError, + match="include_table_lineage must be True for include_view_lineage to be set", + ): + SnowflakeV2Config.parse_obj(config_dict) def test_snowflake_config_with_column_lineage_no_table_lineage_throws_error(): - with pytest.raises(ValidationError): - SnowflakeV2Config.parse_obj( - { - "username": "user", - "password": "password", - "account_id": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - "include_column_lineage": True, - 
"include_table_lineage": False, - } - ) + config_dict = default_config_dict.copy() + config_dict["include_column_lineage"] = True + config_dict["include_table_lineage"] = False + with pytest.raises( + ValidationError, + match="include_table_lineage must be True for include_column_lineage to be set", + ): + SnowflakeV2Config.parse_obj(config_dict) def test_snowflake_config_with_no_connect_args_returns_base_connect_args(): - config: SnowflakeV2Config = SnowflakeV2Config.parse_obj( - { - "username": "user", - "password": "password", - "account_id": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - ) + config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(default_config_dict) assert config.get_options()["connect_args"] is not None assert config.get_options()["connect_args"] == { CLIENT_PREFETCH_THREADS: 10, @@ -300,7 +231,10 @@ def test_snowflake_config_with_no_connect_args_returns_base_connect_args(): def test_private_key_set_but_auth_not_changed(): - with pytest.raises(ValidationError): + with pytest.raises( + ValidationError, + match="Either `private_key` and `private_key_path` is set but `authentication_type` is DEFAULT_AUTHENTICATOR. Should be set to 'KEY_PAIR_AUTHENTICATOR' when using key pair authentication", + ): SnowflakeV2Config.parse_obj( { "account_id": "acctname", @@ -310,19 +244,11 @@ def test_private_key_set_but_auth_not_changed(): def test_snowflake_config_with_connect_args_overrides_base_connect_args(): - config: SnowflakeV2Config = SnowflakeV2Config.parse_obj( - { - "username": "user", - "password": "password", - "account_id": "acctname", - "database_pattern": {"allow": {"^demo$"}}, - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - "connect_args": { - CLIENT_PREFETCH_THREADS: 5, - }, - } - ) + config_dict = default_config_dict.copy() + config_dict["connect_args"] = { + CLIENT_PREFETCH_THREADS: 5, + } + config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(config_dict) assert config.get_options()["connect_args"] is not None assert config.get_options()["connect_args"][CLIENT_PREFETCH_THREADS] == 5 assert config.get_options()["connect_args"][CLIENT_SESSION_KEEP_ALIVE] is True @@ -331,35 +257,20 @@ def test_snowflake_config_with_connect_args_overrides_base_connect_args(): @patch("snowflake.connector.connect") def test_test_connection_failure(mock_connect): mock_connect.side_effect = Exception("Failed to connect to snowflake") - config = { - "username": "user", - "password": "password", - "account_id": "missing", - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - report = SnowflakeV2Source.test_connection(config) - assert report is not None - assert report.basic_connectivity - assert not report.basic_connectivity.capable - assert report.basic_connectivity.failure_reason - assert "Failed to connect to snowflake" in report.basic_connectivity.failure_reason + report = test_connection_helpers.run_test_connection( + SnowflakeV2Source, default_config_dict + ) + test_connection_helpers.assert_basic_connectivity_failure( + report, "Failed to connect to snowflake" + ) @patch("snowflake.connector.connect") def test_test_connection_basic_success(mock_connect): - config = { - "username": "user", - "password": "password", - "account_id": "missing", - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - report = SnowflakeV2Source.test_connection(config) - assert report is not None - assert report.basic_connectivity - assert report.basic_connectivity.capable - assert report.basic_connectivity.failure_reason is None + 
report = test_connection_helpers.run_test_connection( + SnowflakeV2Source, default_config_dict + ) + test_connection_helpers.assert_basic_connectivity_success(report) def setup_mock_connect(mock_connect, query_results=None): @@ -400,31 +311,18 @@ def query_results(query): return [] raise ValueError(f"Unexpected query: {query}") - config = { - "username": "user", - "password": "password", - "account_id": "missing", - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } setup_mock_connect(mock_connect, query_results) - report = SnowflakeV2Source.test_connection(config) - assert report is not None - assert report.basic_connectivity - assert report.basic_connectivity.capable - assert report.basic_connectivity.failure_reason is None - - assert report.capability_report - assert report.capability_report[SourceCapability.CONTAINERS].capable - assert not report.capability_report[SourceCapability.SCHEMA_METADATA].capable - failure_reason = report.capability_report[ - SourceCapability.SCHEMA_METADATA - ].failure_reason - assert failure_reason - - assert ( - "Current role TEST_ROLE does not have permissions to use warehouse" - in failure_reason + report = test_connection_helpers.run_test_connection( + SnowflakeV2Source, default_config_dict + ) + test_connection_helpers.assert_basic_connectivity_success(report) + + test_connection_helpers.assert_capability_report( + capability_report=report.capability_report, + success_capabilities=[SourceCapability.CONTAINERS], + failure_capabilities={ + SourceCapability.SCHEMA_METADATA: "Current role TEST_ROLE does not have permissions to use warehouse" + }, ) @@ -445,25 +343,17 @@ def query_results(query): setup_mock_connect(mock_connect, query_results) - config = { - "username": "user", - "password": "password", - "account_id": "missing", - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - report = SnowflakeV2Source.test_connection(config) - assert report is not None - assert report.basic_connectivity - assert report.basic_connectivity.capable - assert report.basic_connectivity.failure_reason is None - assert report.capability_report - - assert report.capability_report[SourceCapability.CONTAINERS].capable - assert not report.capability_report[SourceCapability.SCHEMA_METADATA].capable - assert ( - report.capability_report[SourceCapability.SCHEMA_METADATA].failure_reason - is not None + report = test_connection_helpers.run_test_connection( + SnowflakeV2Source, default_config_dict + ) + test_connection_helpers.assert_basic_connectivity_success(report) + + test_connection_helpers.assert_capability_report( + capability_report=report.capability_report, + success_capabilities=[SourceCapability.CONTAINERS], + failure_capabilities={ + SourceCapability.SCHEMA_METADATA: "Either no tables exist or current role does not have permissions to access them" + }, ) @@ -488,24 +378,19 @@ def query_results(query): setup_mock_connect(mock_connect, query_results) - config = { - "username": "user", - "password": "password", - "account_id": "missing", - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - report = SnowflakeV2Source.test_connection(config) - - assert report is not None - assert report.basic_connectivity - assert report.basic_connectivity.capable - assert report.basic_connectivity.failure_reason is None - assert report.capability_report - - assert report.capability_report[SourceCapability.CONTAINERS].capable - assert report.capability_report[SourceCapability.SCHEMA_METADATA].capable - assert report.capability_report[SourceCapability.DESCRIPTIONS].capable + report = 
test_connection_helpers.run_test_connection( + SnowflakeV2Source, default_config_dict + ) + test_connection_helpers.assert_basic_connectivity_success(report) + + test_connection_helpers.assert_capability_report( + capability_report=report.capability_report, + success_capabilities=[ + SourceCapability.CONTAINERS, + SourceCapability.SCHEMA_METADATA, + SourceCapability.DESCRIPTIONS, + ], + ) @patch("snowflake.connector.connect") @@ -538,25 +423,21 @@ def query_results(query): setup_mock_connect(mock_connect, query_results) - config = { - "username": "user", - "password": "password", - "account_id": "missing", - "warehouse": "COMPUTE_WH", - "role": "sysadmin", - } - report = SnowflakeV2Source.test_connection(config) - assert report is not None - assert report.basic_connectivity - assert report.basic_connectivity.capable - assert report.basic_connectivity.failure_reason is None - assert report.capability_report - - assert report.capability_report[SourceCapability.CONTAINERS].capable - assert report.capability_report[SourceCapability.SCHEMA_METADATA].capable - assert report.capability_report[SourceCapability.DATA_PROFILING].capable - assert report.capability_report[SourceCapability.DESCRIPTIONS].capable - assert report.capability_report[SourceCapability.LINEAGE_COARSE].capable + report = test_connection_helpers.run_test_connection( + SnowflakeV2Source, default_config_dict + ) + test_connection_helpers.assert_basic_connectivity_success(report) + + test_connection_helpers.assert_capability_report( + capability_report=report.capability_report, + success_capabilities=[ + SourceCapability.CONTAINERS, + SourceCapability.SCHEMA_METADATA, + SourceCapability.DATA_PROFILING, + SourceCapability.DESCRIPTIONS, + SourceCapability.LINEAGE_COARSE, + ], + ) def test_aws_cloud_region_from_snowflake_region_id(): @@ -610,11 +491,10 @@ def test_azure_cloud_region_from_snowflake_region_id(): def test_unknown_cloud_region_from_snowflake_region_id(): - with pytest.raises(Exception) as e: + with pytest.raises(Exception, match="Unknown snowflake region"): SnowflakeV2Source.get_cloud_region_from_snowflake_region_id( "somecloud_someregion" ) - assert "Unknown snowflake region" in str(e) def test_snowflake_object_access_entry_missing_object_id(): diff --git a/metadata-ingestion/tests/unit/test_sql_common.py b/metadata-ingestion/tests/unit/test_sql_common.py index e23d290b611f4..a98bf64171122 100644 --- a/metadata-ingestion/tests/unit/test_sql_common.py +++ b/metadata-ingestion/tests/unit/test_sql_common.py @@ -1,8 +1,7 @@ from typing import Dict -from unittest.mock import Mock +from unittest import mock import pytest -from sqlalchemy.engine.reflection import Inspector from datahub.ingestion.source.sql.sql_common import PipelineContext, SQLAlchemySource from datahub.ingestion.source.sql.sql_config import SQLCommonConfig @@ -13,19 +12,24 @@ class _TestSQLAlchemyConfig(SQLCommonConfig): def get_sql_alchemy_url(self): - pass + return "mysql+pymysql://user:pass@localhost:5330" class _TestSQLAlchemySource(SQLAlchemySource): - pass + @classmethod + def create(cls, config_dict, ctx): + config = _TestSQLAlchemyConfig.parse_obj(config_dict) + return cls(config, ctx, "TEST") + + +def get_test_sql_alchemy_source(): + return _TestSQLAlchemySource.create( + config_dict={}, ctx=PipelineContext(run_id="test_ctx") + ) def test_generate_foreign_key(): - config: SQLCommonConfig = _TestSQLAlchemyConfig() - ctx: PipelineContext = PipelineContext(run_id="test_ctx") - platform: str = "TEST" - inspector: Inspector = Mock() - source = 
_TestSQLAlchemySource(config=config, ctx=ctx, platform=platform) + source = get_test_sql_alchemy_source() fk_dict: Dict[str, str] = { "name": "test_constraint", "referred_table": "test_table", @@ -37,7 +41,7 @@ def test_generate_foreign_key(): dataset_urn="test_urn", schema="test_schema", fk_dict=fk_dict, - inspector=inspector, + inspector=mock.Mock(), ) assert fk_dict.get("name") == foreign_key.name @@ -48,11 +52,7 @@ def test_generate_foreign_key(): def test_use_source_schema_for_foreign_key_if_not_specified(): - config: SQLCommonConfig = _TestSQLAlchemyConfig() - ctx: PipelineContext = PipelineContext(run_id="test_ctx") - platform: str = "TEST" - inspector: Inspector = Mock() - source = _TestSQLAlchemySource(config=config, ctx=ctx, platform=platform) + source = get_test_sql_alchemy_source() fk_dict: Dict[str, str] = { "name": "test_constraint", "referred_table": "test_table", @@ -63,7 +63,7 @@ def test_use_source_schema_for_foreign_key_if_not_specified(): dataset_urn="test_urn", schema="test_schema", fk_dict=fk_dict, - inspector=inspector, + inspector=mock.Mock(), ) assert fk_dict.get("name") == foreign_key.name @@ -105,14 +105,32 @@ def test_get_platform_from_sqlalchemy_uri(uri: str, expected_platform: str) -> N def test_get_db_schema_with_dots_in_view_name(): - config: SQLCommonConfig = _TestSQLAlchemyConfig() - ctx: PipelineContext = PipelineContext(run_id="test_ctx") - platform: str = "TEST" - source = _TestSQLAlchemySource(config=config, ctx=ctx, platform=platform) - + source = get_test_sql_alchemy_source() database, schema = source.get_db_schema( dataset_identifier="database.schema.long.view.name1" ) - assert database == "database" assert schema == "schema" + + +def test_test_connection_success(): + source = get_test_sql_alchemy_source() + with mock.patch( + "datahub.ingestion.source.sql.sql_common.SQLAlchemySource.get_inspectors", + side_effect=lambda: [], + ): + report = source.test_connection({}) + assert report is not None + assert report.basic_connectivity + assert report.basic_connectivity.capable + assert report.basic_connectivity.failure_reason is None + + +def test_test_connection_failure(): + source = get_test_sql_alchemy_source() + report = source.test_connection({}) + assert report is not None + assert report.basic_connectivity + assert not report.basic_connectivity.capable + assert report.basic_connectivity.failure_reason + assert "Connection refused" in report.basic_connectivity.failure_reason
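
Taken together, the changes above converge on one pattern: a source opts in by extending TestableSource and implementing a static test_connection() that returns a TestConnectionReport, and tests drive it through the new tests/test_helpers/test_connection_helpers.py. Below is a minimal sketch of that pattern for a hypothetical new source; MyPlatformConfig, MyPlatformSource, and probe_connection() are invented for illustration, while TestableSource, TestConnectionReport, CapabilityReport, and parse_obj_allow_extras() are the same APIs the diff itself imports and calls.

    # Illustrative sketch only -- not part of the patch above.
    from datahub.configuration.common import ConfigModel
    from datahub.ingestion.api.source import (
        CapabilityReport,
        TestableSource,
        TestConnectionReport,
    )


    class MyPlatformConfig(ConfigModel):
        # Hypothetical connection settings for the example.
        host_port: str = "localhost:1234"
        token: str = ""


    def probe_connection(config: MyPlatformConfig) -> None:
        # Hypothetical cheap connectivity check; it should raise on failure,
        # mirroring what list_topics(), make_tableau_client(), etc. do above.
        raise NotImplementedError("probe_connection is a stub")


    class MyPlatformSource(TestableSource):
        # A real source would also extend its concrete base class
        # (e.g. StatefulIngestionSourceBase), as every source in this diff does.

        @staticmethod
        def test_connection(config_dict: dict) -> TestConnectionReport:
            test_report = TestConnectionReport()
            try:
                config = MyPlatformConfig.parse_obj_allow_extras(config_dict)
                probe_connection(config)
                test_report.basic_connectivity = CapabilityReport(capable=True)
            except Exception as e:
                test_report.basic_connectivity = CapabilityReport(
                    capable=False, failure_reason=str(e)
                )
            return test_report

A test would then exercise it the same way the new dbt, Kafka, MySQL, Power BI, Tableau, and Snowflake tests do: call test_connection_helpers.run_test_connection(MyPlatformSource, config_dict) and assert on the result with assert_basic_connectivity_success / assert_basic_connectivity_failure, adding assert_capability_report only when the source populates per-capability checks as the Kafka and Snowflake sources do.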