feat(ingestion): Add test_connection methods for important sources (d…
shubhamjagtap639 authored and Salman-Apptware committed Dec 15, 2023
1 parent c3d217c commit 0565b26
Showing 15 changed files with 684 additions and 381 deletions.
89 changes: 60 additions & 29 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py
@@ -14,7 +14,12 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.source.dbt.dbt_common import (
DBTColumn,
DBTCommonConfig,
@@ -177,7 +182,7 @@ class DBTCloudConfig(DBTCommonConfig):
@support_status(SupportStatus.INCUBATING)
@capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
class DBTCloudSource(DBTSourceBase):
class DBTCloudSource(DBTSourceBase, TestableSource):
"""
This source pulls dbt metadata directly from the dbt Cloud APIs.
@@ -199,6 +204,57 @@ def create(cls, config_dict, ctx):
config = DBTCloudConfig.parse_obj(config_dict)
return cls(config, ctx, "dbt")

@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = DBTCloudConfig.parse_obj_allow_extras(config_dict)
DBTCloudSource._send_graphql_query(
metadata_endpoint=source_config.metadata_endpoint,
token=source_config.token,
query=_DBT_GRAPHQL_QUERY.format(type="tests", fields="jobId"),
variables={
"jobId": source_config.job_id,
"runId": source_config.run_id,
},
)
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report

@staticmethod
def _send_graphql_query(
metadata_endpoint: str, token: str, query: str, variables: Dict
) -> Dict:
logger.debug(f"Sending GraphQL query to dbt Cloud: {query}")
response = requests.post(
metadata_endpoint,
json={
"query": query,
"variables": variables,
},
headers={
"Authorization": f"Bearer {token}",
"X-dbt-partner-source": "acryldatahub",
},
)

try:
res = response.json()
if "errors" in res:
raise ValueError(
f'Unable to fetch metadata from dbt Cloud: {res["errors"]}'
)
data = res["data"]
except JSONDecodeError as e:
response.raise_for_status()
raise e

return data

def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
# TODO: In dbt Cloud, commands are scheduled as part of jobs, where
# each job can have multiple runs. We currently only fully support
@@ -213,6 +269,8 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
for node_type, fields in _DBT_FIELDS_BY_TYPE.items():
logger.info(f"Fetching {node_type} from dbt Cloud")
data = self._send_graphql_query(
metadata_endpoint=self.config.metadata_endpoint,
token=self.config.token,
query=_DBT_GRAPHQL_QUERY.format(type=node_type, fields=fields),
variables={
"jobId": self.config.job_id,
@@ -232,33 +290,6 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:

return nodes, additional_metadata

def _send_graphql_query(self, query: str, variables: Dict) -> Dict:
logger.debug(f"Sending GraphQL query to dbt Cloud: {query}")
response = requests.post(
self.config.metadata_endpoint,
json={
"query": query,
"variables": variables,
},
headers={
"Authorization": f"Bearer {self.config.token}",
"X-dbt-partner-source": "acryldatahub",
},
)

try:
res = response.json()
if "errors" in res:
raise ValueError(
f'Unable to fetch metadata from dbt Cloud: {res["errors"]}'
)
data = res["data"]
except JSONDecodeError as e:
response.raise_for_status()
raise e

return data

def _parse_into_dbt_node(self, node: Dict) -> DBTNode:
key = node["uniqueId"]

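For context, a minimal sketch (not part of this commit) of calling the new static hook on the dbt Cloud source directly. The token, job_id, run_id, and metadata_endpoint keys mirror the fields referenced in the diff above; the account/project identifiers and target_platform are assumptions about the rest of DBTCloudConfig.

from datahub.ingestion.source.dbt.dbt_cloud import DBTCloudSource

# Hypothetical config values; only token/job_id/run_id/metadata_endpoint are
# taken from the diff above, the remaining keys are assumed config fields.
report = DBTCloudSource.test_connection(
    {
        "token": "<dbt-cloud-service-token>",
        "account_id": 12345,
        "project_id": 67890,
        "job_id": 111,
        "run_id": 222,
        "target_platform": "snowflake",
    }
)
print(report.basic_connectivity)  # CapabilityReport(capable=..., failure_reason=...)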
56 changes: 43 additions & 13 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py
@@ -18,7 +18,12 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
from datahub.ingestion.source.dbt.dbt_common import (
DBTColumn,
@@ -60,11 +65,6 @@ class DBTCoreConfig(DBTCommonConfig):

_github_info_deprecated = pydantic_renamed_field("github_info", "git_info")

@property
def s3_client(self):
assert self.aws_connection
return self.aws_connection.get_s3_client()

@validator("aws_connection")
def aws_connection_needed_if_s3_uris_present(
cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
@@ -363,7 +363,7 @@ def load_test_results(
@support_status(SupportStatus.CERTIFIED)
@capability(SourceCapability.DELETION_DETECTION, "Enabled via stateful ingestion")
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
class DBTCoreSource(DBTSourceBase):
class DBTCoreSource(DBTSourceBase, TestableSource):
"""
The artifacts used by this source are:
- [dbt manifest file](https://docs.getdbt.com/reference/artifacts/manifest-json)
@@ -387,12 +387,34 @@ def create(cls, config_dict, ctx):
config = DBTCoreConfig.parse_obj(config_dict)
return cls(config, ctx, "dbt")

def load_file_as_json(self, uri: str) -> Any:
@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
source_config = DBTCoreConfig.parse_obj_allow_extras(config_dict)
DBTCoreSource.load_file_as_json(
source_config.manifest_path, source_config.aws_connection
)
DBTCoreSource.load_file_as_json(
source_config.catalog_path, source_config.aws_connection
)
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report

@staticmethod
def load_file_as_json(
uri: str, aws_connection: Optional[AwsConnectionConfig]
) -> Dict:
if re.match("^https?://", uri):
return json.loads(requests.get(uri).text)
elif re.match("^s3://", uri):
u = urlparse(uri)
response = self.config.s3_client.get_object(
assert aws_connection
response = aws_connection.get_s3_client().get_object(
Bucket=u.netloc, Key=u.path.lstrip("/")
)
return json.loads(response["Body"].read().decode("utf-8"))
@@ -410,12 +432,18 @@ def loadManifestAndCatalog(
Optional[str],
Optional[str],
]:
dbt_manifest_json = self.load_file_as_json(self.config.manifest_path)
dbt_manifest_json = self.load_file_as_json(
self.config.manifest_path, self.config.aws_connection
)

dbt_catalog_json = self.load_file_as_json(self.config.catalog_path)
dbt_catalog_json = self.load_file_as_json(
self.config.catalog_path, self.config.aws_connection
)

if self.config.sources_path is not None:
dbt_sources_json = self.load_file_as_json(self.config.sources_path)
dbt_sources_json = self.load_file_as_json(
self.config.sources_path, self.config.aws_connection
)
sources_results = dbt_sources_json["results"]
else:
sources_results = {}
@@ -491,7 +519,9 @@ def load_nodes(self) -> Tuple[List[DBTNode], Dict[str, Optional[str]]]:
# This will populate the test_results field on each test node.
all_nodes = load_test_results(
self.config,
self.load_file_as_json(self.config.test_results_path),
self.load_file_as_json(
self.config.test_results_path, self.config.aws_connection
),
all_nodes,
)

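A similar sketch (not part of this commit) for the dbt Core variant, which only needs to verify that the manifest and catalog artifacts are readable, whether local, http(s), or s3 URIs. The paths and the target_platform value are illustrative assumptions.

from datahub.ingestion.source.dbt.dbt_core import DBTCoreSource

report = DBTCoreSource.test_connection(
    {
        "manifest_path": "./target/manifest.json",  # illustrative local paths
        "catalog_path": "./target/catalog.json",
        "target_platform": "postgres",              # assumed required config field
    }
)
if report.basic_connectivity and not report.basic_connectivity.capable:
    print(report.basic_connectivity.failure_reason)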
74 changes: 66 additions & 8 deletions metadata-ingestion/src/datahub/ingestion/source/kafka.py
@@ -15,6 +15,7 @@
ConfigResource,
TopicMetadata,
)
from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient

from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.kafka import KafkaConsumerConnectionConfig
@@ -40,7 +41,13 @@
support_status,
)
from datahub.ingestion.api.registry import import_path
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceCapability
from datahub.ingestion.api.source import (
CapabilityReport,
MetadataWorkUnitProcessor,
SourceCapability,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.kafka_schema_registry_base import KafkaSchemaRegistryBase
@@ -133,6 +140,18 @@ class KafkaSourceConfig(
)


def get_kafka_consumer(
connection: KafkaConsumerConnectionConfig,
) -> confluent_kafka.Consumer:
return confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": connection.bootstrap,
**connection.consumer_config,
}
)


@dataclass
class KafkaSourceReport(StaleEntityRemovalSourceReport):
topics_scanned: int = 0
@@ -145,6 +164,45 @@ def report_dropped(self, topic: str) -> None:
self.filtered.append(topic)


class KafkaConnectionTest:
def __init__(self, config_dict: dict):
self.config = KafkaSourceConfig.parse_obj_allow_extras(config_dict)
self.report = KafkaSourceReport()
self.consumer: confluent_kafka.Consumer = get_kafka_consumer(
self.config.connection
)

def get_connection_test(self) -> TestConnectionReport:
capability_report = {
SourceCapability.SCHEMA_METADATA: self.schema_registry_connectivity(),
}
return TestConnectionReport(
basic_connectivity=self.basic_connectivity(),
capability_report={
k: v for k, v in capability_report.items() if v is not None
},
)

def basic_connectivity(self) -> CapabilityReport:
try:
self.consumer.list_topics(timeout=10)
return CapabilityReport(capable=True)
except Exception as e:
return CapabilityReport(capable=False, failure_reason=str(e))

def schema_registry_connectivity(self) -> CapabilityReport:
try:
SchemaRegistryClient(
{
"url": self.config.connection.schema_registry_url,
**self.config.connection.schema_registry_config,
}
).get_subjects()
return CapabilityReport(capable=True)
except Exception as e:
return CapabilityReport(capable=False, failure_reason=str(e))


@platform_name("Kafka")
@config_class(KafkaSourceConfig)
@support_status(SupportStatus.CERTIFIED)
@@ -160,7 +218,7 @@ def report_dropped(self, topic: str) -> None:
SourceCapability.SCHEMA_METADATA,
"Schemas associated with each topic are extracted from the schema registry. Avro and Protobuf (certified), JSON (incubating). Schema references are supported.",
)
class KafkaSource(StatefulIngestionSourceBase):
class KafkaSource(StatefulIngestionSourceBase, TestableSource):
"""
This plugin extracts the following:
- Topics from the Kafka broker
@@ -183,12 +241,8 @@ def create_schema_registry(
def __init__(self, config: KafkaSourceConfig, ctx: PipelineContext):
super().__init__(config, ctx)
self.source_config: KafkaSourceConfig = config
self.consumer: confluent_kafka.Consumer = confluent_kafka.Consumer(
{
"group.id": "test",
"bootstrap.servers": self.source_config.connection.bootstrap,
**self.source_config.connection.consumer_config,
}
self.consumer: confluent_kafka.Consumer = get_kafka_consumer(
self.source_config.connection
)
self.init_kafka_admin_client()
self.report: KafkaSourceReport = KafkaSourceReport()
@@ -226,6 +280,10 @@ def init_kafka_admin_client(self) -> None:
f"Failed to create Kafka Admin Client due to error {e}.",
)

@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
return KafkaConnectionTest(config_dict).get_connection_test()

@classmethod
def create(cls, config_dict: Dict, ctx: PipelineContext) -> "KafkaSource":
config: KafkaSourceConfig = KafkaSourceConfig.parse_obj(config_dict)
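And a sketch (not part of this commit) for Kafka, where broker reachability is reported as basic connectivity and schema-registry access under SCHEMA_METADATA. The connection values are placeholders; the bootstrap and schema_registry_url keys are the ones referenced in the code above.

from datahub.ingestion.source.kafka import KafkaSource

report = KafkaSource.test_connection(
    {
        "connection": {
            "bootstrap": "localhost:9092",                   # placeholder broker
            "schema_registry_url": "http://localhost:8081",  # placeholder registry
        }
    }
)
print(report.basic_connectivity)
print(report.capability_report)  # may include SourceCapability.SCHEMA_METADATA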
22 changes: 20 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py
@@ -19,7 +19,13 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, SourceReport
from datahub.ingestion.api.source import (
CapabilityReport,
MetadataWorkUnitProcessor,
SourceReport,
TestableSource,
TestConnectionReport,
)
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import (
@@ -1147,7 +1153,7 @@ def report_to_datahub_work_units(
SourceCapability.LINEAGE_FINE,
"Disabled by default, configured using `extract_column_level_lineage`. ",
)
class PowerBiDashboardSource(StatefulIngestionSourceBase):
class PowerBiDashboardSource(StatefulIngestionSourceBase, TestableSource):
"""
This plugin extracts the following:
- Power BI dashboards, tiles and datasets
@@ -1186,6 +1192,18 @@ def __init__(self, config: PowerBiDashboardSourceConfig, ctx: PipelineContext):
self, self.source_config, self.ctx
)

@staticmethod
def test_connection(config_dict: dict) -> TestConnectionReport:
test_report = TestConnectionReport()
try:
PowerBiAPI(PowerBiDashboardSourceConfig.parse_obj_allow_extras(config_dict))
test_report.basic_connectivity = CapabilityReport(capable=True)
except Exception as e:
test_report.basic_connectivity = CapabilityReport(
capable=False, failure_reason=str(e)
)
return test_report

@classmethod
def create(cls, config_dict, ctx):
config = PowerBiDashboardSourceConfig.parse_obj(config_dict)
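Finally, a sketch (not part of this commit) for Power BI, where the check simply constructs PowerBiAPI from the parsed config. The credential keys shown are assumptions about PowerBiDashboardSourceConfig and are not taken from this diff.

from datahub.ingestion.source.powerbi.powerbi import PowerBiDashboardSource

report = PowerBiDashboardSource.test_connection(
    {
        "tenant_id": "<azure-tenant-id>",       # assumed credential fields
        "client_id": "<app-client-id>",
        "client_secret": "<app-client-secret>",
    }
)
print(report.basic_connectivity)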