Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingestion): Add test_connection methods for important sources #9334

Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 60 additions & 29 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.source.dbt.dbt_common import (
DBTColumn,
DBTCommonConfig,
Expand Down Expand Up @@ -177,7 +182,7 @@ class DBTCloudConfig(DBTCommonConfig):
@support_status(SupportStatus.INCUBATING)
@capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
class DBTCloudSource(DBTSourceBase):
class DBTCloudSource(DBTSourceBase, TestableSource):
"""
This source pulls dbt metadata directly from the dbt Cloud APIs.

Expand All @@ -199,6 +204,57 @@ def create(cls, config_dict, ctx):
config = DBTCloudConfig.parse_obj(config_dict)
return cls(config, ctx, "dbt")

@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = DBTCloudConfig.parse_obj_allow_extras(config_dict)
DBTCloudSource._send_graphql_query(
metadata_endpoint=source_config.metadata_endpoint,
token=source_config.token,
query=_DBT_GRAPHQL_QUERY.format(type="tests", fields="jobId"),
variables={
"jobId": source_config.job_id,
"runId": source_config.run_id,
},
)
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report

@staticmethod
def _send_graphql_query(
metadata_endpoint: str, token: str, query: str, variables: Dict
) -> Dict:
logger.debug(f"Sending GraphQL query to dbt Cloud: {query}")
response = requests.post(
metadata_endpoint,
json={
"query": query,
"variables": variables,
},
headers={
"Authorization": f"Bearer {token}",
"X-dbt-partner-source": "acryldatahub",
},
)

try:
res = response.json()
if "errors" in res:
raise ValueError(
f'Unable to fetch metadata from dbt Cloud: {res["errors"]}'
)
data = res["data"]
except JSONDecodeError as e:
response.raise_for_status()
raise e

return data

def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
# TODO: In dbt Cloud, commands are scheduled as part of jobs, where
# each job can have multiple runs. We currently only fully support
Expand All @@ -213,6 +269,8 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
for node_type, fields in _DBT_FIELDS_BY_TYPE.items():
logger.info(f"Fetching {node_type} from dbt Cloud")
data = self._send_graphql_query(
metadata_endpoint=self.config.metadata_endpoint,
token=self.config.token,
query=_DBT_GRAPHQL_QUERY.format(type=node_type, fields=fields),
variables={
"jobId": self.config.job_id,
Expand All @@ -232,33 +290,6 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:

return nodes, additional_metadata

def _send_graphql_query(self, query: str, variables: Dict) -> Dict:
logger.debug(f"Sending GraphQL query to dbt Cloud: {query}")
response = requests.post(
self.config.metadata_endpoint,
json={
"query": query,
"variables": variables,
},
headers={
"Authorization": f"Bearer {self.config.token}",
"X-dbt-partner-source": "acryldatahub",
},
)

try:
res = response.json()
if "errors" in res:
raise ValueError(
f'Unable to fetch metadata from dbt Cloud: {res["errors"]}'
)
data = res["data"]
except JSONDecodeError as e:
response.raise_for_status()
raise e

return data

def _parse_into_dbt_node(self, node: Dict) -> DBTNode:
key = node["uniqueId"]

Expand Down
56 changes: 43 additions & 13 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
from datahub.ingestion.source.dbt.dbt_common import (
DBTColumn,
Expand Down Expand Up @@ -60,11 +65,6 @@ class DBTCoreConfig(DBTCommonConfig):

_github_info_deprecated = pydantic_renamed_field("github_info", "git_info")

@property
def s3_client(self):
assert self.aws_connection
return self.aws_connection.get_s3_client()

@validator("aws_connection")
def aws_connection_needed_if_s3_uris_present(
cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
Expand Down Expand Up @@ -363,7 +363,7 @@ def load_test_results(
@support_status(SupportStatus.CERTIFIED)
@capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
class DBTCoreSource(DBTSourceBase):
class DBTCoreSource(DBTSourceBase, TestableSource):
"""
The artifacts used by this source are:
- [dbt manifest file](https://docs.getdbt.com/reference/artifacts/manifest-json)
Expand All @@ -387,12 +387,34 @@ def create(cls, config_dict, ctx):
config = DBTCoreConfig.parse_obj(config_dict)
return cls(config, ctx, "dbt")

def load_file_as_json(self, uri: str) -> Any:
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = DBTCoreConfig.parse_obj_allow_extras(config_dict)
DBTCoreSource.load_file_as_json(
source_config.manifest_path, source_config.aws_connection
)
DBTCoreSource.load_file_as_json(
source_config.catalog_path, source_config.aws_connection
)
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report

@staticmethod
def load_file_as_json(
uri: str, aws_connection: Optional[AwsConnectionConfig]
) -> Dict:
if re.match("^https?://", uri):
return json.loads(requests.get(uri).text)
elif re.match("^s3://", uri):
u = urlparse(uri)
response = self.config.s3_client.get_object(
assert aws_connection
response = aws_connection.get_s3_client().get_object(
shubhamjagtap639 marked this conversation as resolved.
Show resolved Hide resolved
Bucket=u.netloc, Key=u.path.lstrip("/")
)
return json.loads(response["Body"].read().decode("utf-8"))
Expand All @@ -410,12 +432,18 @@ def loadManifestAndCatalog(
Optional[str],
Optional[str],
]:
dbt_manifest_json = self.load_file_as_json(self.config.manifest_path)
dbt_manifest_json = self.load_file_as_json(
self.config.manifest_path, self.config.aws_connection
)

dbt_catalog_json = self.load_file_as_json(self.config.catalog_path)
dbt_catalog_json = self.load_file_as_json(
self.config.catalog_path, self.config.aws_connection
)

if self.config.sources_path is not None:
dbt_sources_json = self.load_file_as_json(self.config.sources_path)
dbt_sources_json = self.load_file_as_json(
self.config.sources_path, self.config.aws_connection
)
sources_results = dbt_sources_json["results"]
else:
sources_results = {}
Expand Down Expand Up @@ -478,7 +506,9 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
# This will populate the test_results field on each test node.
all_nodes = load_test_results(
self.config,
self.load_file_as_json(self.config.test_results_path),
self.load_file_as_json(
self.config.test_results_path, self.config.aws_connection
),
all_nodes,
)

Expand Down
74 changes: 66 additions & 8 deletions metadata-ingestion/src/datahub/ingestion/source/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ConfigResource,
TopicMetadata,
)
from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient

from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.kafka import KafkaConsumerConnectionConfig
Expand All @@ -40,7 +41,13 @@
support_status,
)
from datahub.ingestion.api.registry import import_path
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
MetadataWorkUnitProcessor,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
Expand Down Expand Up @@ -133,6 +140,18 @@ class KafkaSourceConfig(
)


def get_kafka_consumer(
connection: KafkaConsumerConnectionConfig,
) -> confluent_kafka.Consumer:
return confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": connection.bootstrap,
**connection.consumer_config,
}
)


@dataclass
class KafkaSourceReport(StaleEntityRemovalSourceReport):
topics_scanned: int = 0
Expand All @@ -145,6 +164,45 @@ def report_dropped(self, topic: str) -> None:
self.filtered.append(topic)


class KafkaConnectionTest:
def __init__(self, config_dict: dict):
self.config = KafkaSourceConfig.parse_obj_allow_extras(config_dict)
self.report = KafkaSourceReport()
self.consumer: confluent_kafka.Consumer = get_kafka_consumer(
self.config.connection
)

def get_connection_test(self) -> TestConnectionReport:
capability_report = {
SourceCapability.SCHEMA_METADATA: self.schema_registry_connectivity(),
}
return TestConnectionReport(
basic_connectivity=self.basic_connectivity(),
capability_report={
k: v for k, v in capability_report.items() if v is not None
},
)

def basic_connectivity(self) -> CapabilityReport:
try:
self.consumer.list_topics(timeout=10)
return CapabilityReport(capable=True)
except Exception as e:
return CapabilityReport(capable=False, failure_reason=str(e))

def schema_registry_connectivity(self) -> CapabilityReport:
try:
SchemaRegistryClient(
{
"url": self.config.connection.schema_registry_url,
**self.config.connection.schema_registry_config,
}
).get_subjects()
return CapabilityReport(capable=True)
except Exception as e:
return CapabilityReport(capable=False, failure_reason=str(e))


@platform_name("Kafka")
@config_class(KafkaSourceConfig)
@support_status(SupportStatus.CERTIFIED)
Expand All @@ -160,7 +218,7 @@ def report_dropped(self, topic: str) -> None:
SourceCapability.SCHEMA_METADATA,
"Schemas associated with each topic are extracted from the schema registry. Avro and Protobuf (certified), JSON (incubating). Schema references are supported.",
)
class KafkaSource(StatefulIngestionSourceBase):
class KafkaSource(StatefulIngestionSourceBase, TestableSource):
"""
This plugin extracts the following:
- Topics from the Kafka broker
Expand All @@ -183,12 +241,8 @@ def create_schema_registry(
def __init__(self, config: KafkaSourceConfig, ctx: PipelineContext):
super().__init__(config, ctx)
self.source_config: KafkaSourceConfig = config
self.consumer: confluent_kafka.Consumer = confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": self.source_config.connection.bootstrap,
**self.source_config.connection.consumer_config,
}
self.consumer: confluent_kafka.Consumer = get_kafka_consumer(
self.source_config.connection
)
self.init_kafka_admin_client()
self.report: KafkaSourceReport = KafkaSourceReport()
Expand Down Expand Up @@ -226,6 +280,10 @@ def init_kafka_admin_client(self) -> None:
f"Failed to create Kafka Admin Client due to error {e}.",
)

@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
return KafkaConnectionTest(config_dict).get_connection_test()

@classmethod
def create(cls, config_dict: Dict, ctx: PipelineContext) -> "KafkaSource":
config: KafkaSourceConfig = KafkaSourceConfig.parse_obj(config_dict)
Expand Down
22 changes: 20 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceReport
from datahub.ingestion.api.source import (
CapabilityReport,
MetadataWorkUnitProcessor,
SourceReport,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import (
Expand Down Expand Up @@ -1143,7 +1149,7 @@ def report_to_datahub_work_units(
SourceCapability.LINEAGE_FINE,
"Disabled by default, configured using `extract_column_level_lineage`. ",
)
class PowerBiDashboardSource(StatefulIngestionSourceBase):
class PowerBiDashboardSource(StatefulIngestionSourceBase, TestableSource):
"""
This plugin extracts the following:
- Power BI dashboards, tiles and datasets
Expand Down Expand Up @@ -1182,6 +1188,18 @@ def __init__(self, config: PowerBiDashboardSourceConfig, ctx: PipelineContext):
self, self.source_config, self.ctx
)

@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
PowerBiAPI(PowerBiDashboardSourceConfig.parse_obj_allow_extras(config_dict))
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report

@classmethod
def create(cls, config_dict, ctx):
config = PowerBiDashboardSourceConfig.parse_obj(config_dict)
Expand Down
Loading
Loading