feat(dbt): enable dbt read artifacts from s3 #4935

Merged
2 changes: 1 addition & 1 deletion metadata-ingestion/src/datahub/emitter/rest_emitter.py
@@ -30,7 +30,7 @@ def _make_curl_command(
*itertools.chain(
*[
("-X", method),
*[("-H", f"{k}: {v}") for (k, v) in session.headers.items()],
*[("-H", f"{k!s}: {v!s}") for (k, v) in session.headers.items()],
("--data", payload),
]
),
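Note on the rest_emitter.py change above: an f-string placeholder like {k} calls the object's __format__, while the !s conversion forces a plain str() call first. A minimal sketch of the difference (not from the PR; Token is a hypothetical stand-in for a non-string header value):

class Token:
    def __str__(self) -> str:
        return "Bearer abc123"

    def __format__(self, spec: str) -> str:
        raise TypeError("Token does not support format specifiers")


token = Token()
print(f"{token!s}")  # "Bearer abc123" -- !s routes through str()
# print(f"{token}")  # would raise TypeError, since plain {token} calls __format__
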
33 changes: 22 additions & 11 deletions metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
@@ -7,7 +7,7 @@
from botocore.utils import fix_s3_host
from pydantic.fields import Field

from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.common import AllowDenyPattern, ConfigModel
from datahub.configuration.source_common import EnvBasedSourceConfigBase

if TYPE_CHECKING:
@@ -35,24 +35,16 @@ def assume_role(
return assumed_role_object["Credentials"]


class AwsSourceConfig(EnvBasedSourceConfigBase):
class AwsConnectionConfig(ConfigModel):
"""
Common AWS credentials config.

Currently used by:
- Glue source
- SageMaker source
- dbt source
"""

database_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="regex patterns for databases to filter in ingestion.",
)
table_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="regex patterns for tables to filter in ingestion.",
)

aws_access_key_id: Optional[str] = Field(
default=None,
description="Autodetected. See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html",
@@ -157,3 +149,22 @@ def get_glue_client(self) -> "GlueClient":

def get_sagemaker_client(self) -> "SageMakerClient":
return self.get_session().client("sagemaker")


class AwsSourceConfig(EnvBasedSourceConfigBase, AwsConnectionConfig):
"""
Common AWS credentials config.

Currently used by:
- Glue source
- SageMaker source
"""

database_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="regex patterns for databases to filter in ingestion.",
)
table_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="regex patterns for tables to filter in ingestion.",
)
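Note on the aws_common.py refactor above: the credential and session handling now lives in AwsConnectionConfig, which the dbt source can reuse on its own, while AwsSourceConfig layers the Glue/SageMaker database and table patterns on top via inheritance. A hypothetical usage sketch (all values are placeholders; aws_region is an existing field of this config that the diff does not show, and get_s3_client is assumed to be added alongside the Glue/SageMaker helpers in the collapsed hunk):

from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig

connection = AwsConnectionConfig(
    aws_region="us-east-1",            # placeholder region
    aws_access_key_id="AKIA...",       # placeholder; boto3 can also autodetect credentials
    aws_secret_access_key="<secret>",  # placeholder
)
s3_client = connection.get_s3_client()  # plain boto3 S3 client, no source-level patterns needed
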
187 changes: 111 additions & 76 deletions metadata-ingestion/src/datahub/ingestion/source/dbt.py
@@ -3,6 +3,7 @@
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast
from urllib.parse import urlparse

import dateutil.parser
import requests
@@ -23,6 +24,7 @@
)
from datahub.ingestion.api.ingestion_job_state_provider import JobId
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig
from datahub.ingestion.source.sql.sql_types import (
BIGQUERY_TYPES_MAP,
POSTGRES_TYPES_MAP,
@@ -172,6 +174,15 @@ class DBTConfig(StatefulIngestionConfigBase):
default=None,
description='Regex string to extract owner from the dbt node using the `(?P<name>...)` syntax of the [match object](https://docs.python.org/3/library/re.html#match-objects), where the group name must be `owner`. Examples: (1) `r"(?P<owner>(.*)): (\w+) (\w+)"` will extract `jdoe` as the owner from `"jdoe: John Doe"`; (2) `r"@(?P<owner>(.*))"` will extract `alice` as the owner from `"@alice"`.', # noqa: W605
)
aws_connection: Optional[AwsConnectionConfig] = Field(
default=None,
description="AWS connection details, used when fetching manifest/catalog/sources files from s3:// URIs",
)

@property
def s3_client(self):
assert self.aws_connection
return self.aws_connection.get_s3_client()
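Note: with the aws_connection field and s3_client property above, pointing the dbt source at s3-hosted artifacts looks roughly like this hypothetical sketch (bucket and key names are placeholders, and target_platform is assumed to be the only other required field):

from datahub.ingestion.source.dbt import DBTConfig

config = DBTConfig.parse_obj(
    {
        "manifest_path": "s3://my-bucket/dbt/manifest.json",  # placeholder URI
        "catalog_path": "s3://my-bucket/dbt/catalog.json",    # placeholder URI
        "target_platform": "postgres",
        "aws_connection": {"aws_region": "us-east-1"},  # credentials autodetected by boto3
    }
)
client = config.s3_client  # boto3 S3 client built from aws_connection
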

# Custom Stateful Ingestion settings
stateful_ingestion: Optional[DBTStatefulIngestionConfig] = Field(
@@ -197,6 +208,22 @@ def validate_write_semantics(cls, write_semantics: str) -> str:
)
return write_semantics

@validator("aws_connection")
def aws_connection_needed_if_s3_uris_present(
cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any
) -> Optional[AwsConnectionConfig]:
# first check if there are fields that contain s3 uris
uri_containing_fields = [
f
for f in ["manifest_path", "catalog_path", "sources_path"]
if (values.get(f) or "").startswith("s3://")
]
if uri_containing_fields and not aws_connection:
raise ValueError(
f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}"
)
return aws_connection
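Continuing the hypothetical sketch above: omitting aws_connection while any of the three path fields uses an s3:// URI now fails validation up front instead of crashing later at read time.

from pydantic import ValidationError

try:
    DBTConfig.parse_obj(
        {
            "manifest_path": "s3://my-bucket/dbt/manifest.json",  # placeholder URI
            "catalog_path": "s3://my-bucket/dbt/catalog.json",    # placeholder URI
            "target_platform": "postgres",
        }
    )
except ValidationError as e:
    print(e)  # names aws_connection and the offending s3 uri fields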


@dataclass
class DBTColumn:
@@ -387,80 +414,6 @@ def extract_dbt_entities(
return dbt_entities


def load_file_as_json(uri: str) -> Any:
if re.match("^https?://", uri):
return json.loads(requests.get(uri).text)
else:
with open(uri, "r") as f:
return json.load(f)


def loadManifestAndCatalog(
manifest_path: str,
catalog_path: str,
sources_path: Optional[str],
load_schemas: bool,
use_identifiers: bool,
tag_prefix: str,
node_type_pattern: AllowDenyPattern,
report: DBTSourceReport,
node_name_pattern: AllowDenyPattern,
) -> Tuple[
List[DBTNode],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Dict[str, Dict[str, Any]],
]:
dbt_manifest_json = load_file_as_json(manifest_path)

dbt_catalog_json = load_file_as_json(catalog_path)

if sources_path is not None:
dbt_sources_json = load_file_as_json(sources_path)
sources_results = dbt_sources_json["results"]
else:
sources_results = {}

manifest_schema = dbt_manifest_json.get("metadata", {}).get("dbt_schema_version")
manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")

catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")

manifest_nodes = dbt_manifest_json["nodes"]
manifest_sources = dbt_manifest_json["sources"]

all_manifest_entities = {**manifest_nodes, **manifest_sources}

catalog_nodes = dbt_catalog_json["nodes"]
catalog_sources = dbt_catalog_json["sources"]

all_catalog_entities = {**catalog_nodes, **catalog_sources}

nodes = extract_dbt_entities(
all_manifest_entities,
all_catalog_entities,
sources_results,
load_schemas,
use_identifiers,
tag_prefix,
node_type_pattern,
report,
node_name_pattern,
)

return (
nodes,
manifest_schema,
manifest_version,
catalog_schema,
catalog_version,
all_manifest_entities,
)


def get_db_fqn(database: Optional[str], schema: str, name: str) -> str:
if database is not None:
fqn = f"{database}.{schema}.{name}"
@@ -760,6 +713,88 @@ def soft_delete_item(urn: str, type: str) -> Iterable[MetadataWorkUnit]:
):
yield from soft_delete_item(table_urn, "dataset")

# Example artifact URI: s3://data-analysis.pelotime.com/dbt-artifacts/data-engineering-dbt/catalog.json
def load_file_as_json(self, uri: str) -> Any:
if re.match("^https?://", uri):
return json.loads(requests.get(uri).text)
elif re.match("^s3://", uri):
u = urlparse(uri)
response = self.config.s3_client.get_object(
Bucket=u.netloc, Key=u.path.lstrip("/")
)
return json.loads(response["Body"].read().decode("utf-8"))
else:
with open(uri, "r") as f:
return json.load(f)
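Note: urlparse maps an s3 URI onto the Bucket/Key pair used by the get_object call above. A quick check with a placeholder URI:

from urllib.parse import urlparse

u = urlparse("s3://my-bucket/dbt-artifacts/manifest.json")
print(u.netloc)            # "my-bucket" -> Bucket
print(u.path.lstrip("/"))  # "dbt-artifacts/manifest.json" -> Key
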

def loadManifestAndCatalog(
self,
manifest_path: str,
catalog_path: str,
sources_path: Optional[str],
load_schemas: bool,
use_identifiers: bool,
tag_prefix: str,
node_type_pattern: AllowDenyPattern,
report: DBTSourceReport,
node_name_pattern: AllowDenyPattern,
) -> Tuple[
List[DBTNode],
Optional[str],
Optional[str],
Optional[str],
Optional[str],
Dict[str, Dict[str, Any]],
]:
dbt_manifest_json = self.load_file_as_json(manifest_path)

dbt_catalog_json = self.load_file_as_json(catalog_path)

if sources_path is not None:
dbt_sources_json = self.load_file_as_json(sources_path)
sources_results = dbt_sources_json["results"]
else:
sources_results = {}

manifest_schema = dbt_manifest_json.get("metadata", {}).get(
"dbt_schema_version"
)
manifest_version = dbt_manifest_json.get("metadata", {}).get("dbt_version")

catalog_schema = dbt_catalog_json.get("metadata", {}).get("dbt_schema_version")
catalog_version = dbt_catalog_json.get("metadata", {}).get("dbt_version")

manifest_nodes = dbt_manifest_json["nodes"]
manifest_sources = dbt_manifest_json["sources"]

all_manifest_entities = {**manifest_nodes, **manifest_sources}

catalog_nodes = dbt_catalog_json["nodes"]
catalog_sources = dbt_catalog_json["sources"]

all_catalog_entities = {**catalog_nodes, **catalog_sources}

nodes = extract_dbt_entities(
all_manifest_entities,
all_catalog_entities,
sources_results,
load_schemas,
use_identifiers,
tag_prefix,
node_type_pattern,
report,
node_name_pattern,
)

return (
nodes,
manifest_schema,
manifest_version,
catalog_schema,
catalog_version,
all_manifest_entities,
)
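Note on loadManifestAndCatalog above: both artifacts share the top-level layout the method relies on; a rough shape sketch (contents abbreviated, following dbt's documented artifact schemas):

manifest_or_catalog = {
    "metadata": {"dbt_schema_version": "...", "dbt_version": "..."},
    "nodes": {},    # keyed by unique id, e.g. "model.jaffle_shop.customers"
    "sources": {},  # keyed by unique id, e.g. "source.jaffle_shop.raw.orders"
}
# sources.json from `dbt source freshness` additionally carries the "results"
# entries consumed above.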

# create workunits from dbt nodes
def get_workunits(self) -> Iterable[MetadataWorkUnit]:
if self.config.write_semantics == "PATCH" and not self.ctx.graph:
@@ -775,7 +810,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
catalog_schema,
catalog_version,
manifest_nodes_raw,
) = loadManifestAndCatalog(
) = self.loadManifestAndCatalog(
self.config.manifest_path,
self.config.catalog_path,
self.config.sources_path,
@@ -1342,7 +1377,7 @@ def get_platform_instance_id(self) -> str:
"""

project_id = (
load_file_as_json(self.config.manifest_path)
self.load_file_as_json(self.config.manifest_path)
.get("metadata", {})
.get("project_id")
)