feat(ingestion/transformer): create tag if not exist (#9076)
siddiquebagwan-gslab authored Dec 14, 2023
1 parent aac1c55 commit 0d6a5e5
Showing 4 changed files with 154 additions and 19 deletions.
24 changes: 24 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/graph/client.py
@@ -787,9 +787,11 @@ def get_aspect_counts(self, aspect: str, urn_like: Optional[str] = None) -> int:

def execute_graphql(self, query: str, variables: Optional[Dict] = None) -> Dict:
url = f"{self.config.server}/api/graphql"

body: Dict = {
"query": query,
}

if variables:
body["variables"] = variables

@@ -1065,6 +1067,28 @@ def parse_sql_lineage(
default_schema=default_schema,
)

def create_tag(self, tag_name: str) -> str:
graph_query: str = """
mutation($tag_detail: CreateTagInput!) {
createTag(input: $tag_detail)
}
"""

variables = {
"tag_detail": {
"name": tag_name,
"id": tag_name,
},
}

res = self.execute_graphql(
query=graph_query,
variables=variables,
)

# Return the URN of the newly created tag.
return res["createTag"]

def close(self) -> None:
self._make_schema_resolver.cache_clear()
super().close()
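
A rough usage sketch of the new helper, not part of the commit itself (the server address is a placeholder): calling create_tag on a connected DataHubGraph runs the createTag mutation above and returns the URN of the tag it created.

from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig

# Placeholder GMS endpoint; point this at a reachable DataHub instance.
graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))

# Executes the createTag GraphQL mutation and returns the tag URN,
# e.g. "urn:li:tag:NeedsDocumentation".
tag_urn = graph.create_tag("NeedsDocumentation")
print(tag_urn)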
@@ -1,14 +1,24 @@
import logging
from typing import Callable, List, Optional, cast

import datahub.emitter.mce_builder as builder
from datahub.configuration.common import (
KeyValuePattern,
TransformerSemanticsConfigModel,
)
from datahub.configuration.import_resolver import pydantic_resolve_key
from datahub.emitter.mce_builder import Aspect
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.transformer.dataset_transformer import DatasetTagsTransformer
from datahub.metadata.schema_classes import GlobalTagsClass, TagAssociationClass
from datahub.metadata.schema_classes import (
GlobalTagsClass,
TagAssociationClass,
TagKeyClass,
)
from datahub.utilities.urns.tag_urn import TagUrn

logger = logging.getLogger(__name__)


class AddDatasetTagsConfig(TransformerSemanticsConfigModel):
@@ -22,11 +32,13 @@ class AddDatasetTags(DatasetTagsTransformer):

ctx: PipelineContext
config: AddDatasetTagsConfig
processed_tags: List[TagAssociationClass]

def __init__(self, config: AddDatasetTagsConfig, ctx: PipelineContext):
super().__init__()
self.ctx = ctx
self.config = config
self.processed_tags = []

@classmethod
def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetTags":
@@ -45,11 +57,38 @@ def transform_aspect(
tags_to_add = self.config.get_tags_to_add(entity_urn)
if tags_to_add is not None:
out_global_tags_aspect.tags.extend(tags_to_add)
self.processed_tags.extend(
tags_to_add
) # Keep track of tags added so that we can create them in handle_end_of_stream

return self.get_result_semantics(
self.config, self.ctx.graph, entity_urn, out_global_tags_aspect
)

def handle_end_of_stream(self) -> List[MetadataChangeProposalWrapper]:

mcps: List[MetadataChangeProposalWrapper] = []

logger.debug("Generating tags")

for tag_association in self.processed_tags:
ids: List[str] = TagUrn.create_from_string(
tag_association.tag
).get_entity_id()

assert len(ids) == 1, "Invalid Tag Urn"

tag_name: str = ids[0]

mcps.append(
MetadataChangeProposalWrapper(
entityUrn=builder.make_tag_urn(tag=tag_name),
aspect=TagKeyClass(name=tag_name),
)
)

return mcps
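
For illustration only, a minimal sketch of how the hook above plays out, assuming AddDatasetTags and AddDatasetTagsConfig from this module are in scope (the run_id and tag name are made up): tag associations are recorded while aspects stream through, and handle_end_of_stream then returns one tagKey MCP per tag so the tag entities themselves get created.

import datahub.emitter.mce_builder as builder
from datahub.ingestion.api.common import PipelineContext
from datahub.metadata.schema_classes import TagAssociationClass

# Hypothetical wiring: always add one fixed tag to every dataset.
config = AddDatasetTagsConfig(
    get_tags_to_add=lambda _urn: [
        TagAssociationClass(tag=builder.make_tag_urn("NeedsDocumentation"))
    ]
)
transformer = AddDatasetTags(config, PipelineContext(run_id="tag-demo"))

# ... after transform_aspect has run over the record stream ...
extra_mcps = transformer.handle_end_of_stream()
# Each MCP targets urn:li:tag:<name> with a TagKeyClass aspect, creating the tag if it is missing.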


class SimpleDatasetTagConfig(TransformerSemanticsConfigModel):
tag_urns: List[str]
@@ -82,6 +121,7 @@ class PatternAddDatasetTags(AddDatasetTags):
"""Transformer that adds a specified set of tags to each dataset."""

def __init__(self, config: PatternDatasetTagsConfig, ctx: PipelineContext):
config.tag_pattern.all
tag_pattern = config.tag_pattern
generic_config = AddDatasetTagsConfig(
get_tags_to_add=lambda _: [
@@ -17,13 +17,30 @@
log = logging.getLogger(__name__)


class LegacyMCETransformer(Transformer, metaclass=ABCMeta):
def _update_work_unit_id(
envelope: RecordEnvelope, urn: str, aspect_name: str
) -> Dict[Any, Any]:
structured_urn = Urn.create_from_string(urn)
simple_name = "-".join(structured_urn.get_entity_id())
record_metadata = envelope.metadata.copy()
record_metadata.update({"workunit_id": f"txform-{simple_name}-{aspect_name}"})
return record_metadata


class HandleEndOfStreamTransformer:
def handle_end_of_stream(self) -> List[MetadataChangeProposalWrapper]:
return []


class LegacyMCETransformer(
Transformer, HandleEndOfStreamTransformer, metaclass=ABCMeta
):
@abstractmethod
def transform_one(self, mce: MetadataChangeEventClass) -> MetadataChangeEventClass:
pass


class SingleAspectTransformer(metaclass=ABCMeta):
class SingleAspectTransformer(HandleEndOfStreamTransformer, metaclass=ABCMeta):
@abstractmethod
def aspect_name(self) -> str:
"""Implement this method to specify a single aspect that the transformer is interested in subscribing to. No default provided."""
@@ -180,6 +197,32 @@ def _transform_or_record_mcpw(
self._record_mcp(envelope.record)
return envelope if envelope.record.aspect is not None else None

def _handle_end_of_stream(
self, envelope: RecordEnvelope
) -> Iterable[RecordEnvelope]:

if not isinstance(self, SingleAspectTransformer) and not isinstance(
self, LegacyMCETransformer
):
return

mcps: List[MetadataChangeProposalWrapper] = self.handle_end_of_stream()

for mcp in mcps:
if mcp.aspect is None or mcp.entityUrn is None:  # to silence the lint error
continue

record_metadata = _update_work_unit_id(
envelope=envelope,
aspect_name=mcp.aspect.get_aspect_name(), # type: ignore
urn=mcp.entityUrn,
)

yield RecordEnvelope(
record=mcp,
metadata=record_metadata,
)

def transform(
self, record_envelopes: Iterable[RecordEnvelope]
) -> Iterable[RecordEnvelope]:
@@ -216,26 +259,32 @@ def transform(
else None,
)
if transformed_aspect:
# for end of stream records, we modify the workunit-id
structured_urn = Urn.create_from_string(urn)
simple_name = "-".join(structured_urn.get_entity_id())
record_metadata = envelope.metadata.copy()
record_metadata.update(
{
"workunit_id": f"txform-{simple_name}-{self.aspect_name()}"
}
)
yield RecordEnvelope(
record=MetadataChangeProposalWrapper(

mcp: MetadataChangeProposalWrapper = (
MetadataChangeProposalWrapper(
entityUrn=urn,
entityType=structured_urn.get_type(),
systemMetadata=last_seen_mcp.systemMetadata
if last_seen_mcp
else last_seen_mce_system_metadata,
aspectName=self.aspect_name(),
aspect=transformed_aspect,
),
)
)

record_metadata = _update_work_unit_id(
envelope=envelope,
aspect_name=mcp.aspect.get_aspect_name(), # type: ignore
urn=mcp.entityUrn,
)

yield RecordEnvelope(
record=mcp,
metadata=record_metadata,
)

self._mark_processed(urn)
yield from self._handle_end_of_stream(envelope=envelope)

yield envelope
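
As a hedged sketch of the extension point added here (the mixin below is hypothetical and not part of this commit): any transformer that inherits HandleEndOfStreamTransformer can override handle_end_of_stream, and transform() wraps whatever MCPs it returns in RecordEnvelopes just before forwarding the EndOfStream record.

from typing import List

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.metadata.schema_classes import TagKeyClass


class TagCreatingEndOfStreamMixin:
    """Hypothetical mixin: collect tag names during the stream, emit tagKey MCPs at the end."""

    def __init__(self) -> None:
        self.collected_tag_names: List[str] = []

    def handle_end_of_stream(self) -> List[MetadataChangeProposalWrapper]:
        # One MCP per collected tag; entityType and aspectName are inferred by the wrapper.
        return [
            MetadataChangeProposalWrapper(
                entityUrn=f"urn:li:tag:{name}",
                aspect=TagKeyClass(name=name),
            )
            for name in self.collected_tag_names
        ]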
32 changes: 27 additions & 5 deletions metadata-ingestion/tests/unit/test_transform_dataset.py
@@ -813,13 +813,25 @@ def test_simple_dataset_tags_transformation(mock_time):
]
)
)
assert len(outputs) == 3

assert len(outputs) == 5

# Check that tags were added.
tags_aspect = outputs[1].record.aspect
assert tags_aspect.tags[0].tag == builder.make_tag_urn("NeedsDocumentation")
assert tags_aspect
assert len(tags_aspect.tags) == 2
assert tags_aspect.tags[0].tag == builder.make_tag_urn("NeedsDocumentation")

# Check that the new tag entities are present.
assert outputs[2].record.aspectName == "tagKey"
assert outputs[2].record.aspect.name == "NeedsDocumentation"
assert outputs[2].record.entityUrn == builder.make_tag_urn("NeedsDocumentation")

assert outputs[3].record.aspectName == "tagKey"
assert outputs[3].record.aspect.name == "Legacy"
assert outputs[3].record.entityUrn == builder.make_tag_urn("Legacy")

assert isinstance(outputs[4].record, EndOfStream)


def dummy_tag_resolver_method(dataset_snapshot):
@@ -853,7 +865,7 @@ def test_pattern_dataset_tags_transformation(mock_time):
)
)

assert len(outputs) == 3
assert len(outputs) == 5
tags_aspect = outputs[1].record.aspect
assert tags_aspect
assert len(tags_aspect.tags) == 2
@@ -1363,7 +1375,7 @@ def test_mcp_add_tags_missing(mock_time):
]
input_stream.append(RecordEnvelope(record=EndOfStream(), metadata={}))
outputs = list(transformer.transform(input_stream))
assert len(outputs) == 3
assert len(outputs) == 5
assert outputs[0].record == dataset_mcp
# Check that tags were added, this will be the second result
tags_aspect = outputs[1].record.aspect
@@ -1395,13 +1407,23 @@ def test_mcp_add_tags_existing(mock_time):
]
input_stream.append(RecordEnvelope(record=EndOfStream(), metadata={}))
outputs = list(transformer.transform(input_stream))
assert len(outputs) == 2

assert len(outputs) == 4

# Check that tags were added; this will be the first result
tags_aspect = outputs[0].record.aspect
assert tags_aspect
assert len(tags_aspect.tags) == 3
assert tags_aspect.tags[0].tag == builder.make_tag_urn("Test")
assert tags_aspect.tags[1].tag == builder.make_tag_urn("NeedsDocumentation")
assert tags_aspect.tags[2].tag == builder.make_tag_urn("Legacy")

# Check tag entities got added
assert outputs[1].record.entityType == "tag"
assert outputs[1].record.entityUrn == builder.make_tag_urn("NeedsDocumentation")
assert outputs[2].record.entityType == "tag"
assert outputs[2].record.entityUrn == builder.make_tag_urn("Legacy")

assert isinstance(outputs[-1].record, EndOfStream)

