Skip to content

Commit

Permalink
feat(ingest): support dynamic imports for transformer methods (#2858)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Jul 12, 2021
1 parent 973c08d commit 220dfe7
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 4 deletions.
13 changes: 13 additions & 0 deletions metadata-ingestion/src/datahub/configuration/import_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pydantic

from datahub.ingestion.api.registry import import_key


def _pydantic_resolver(v):
if isinstance(v, str):
return import_key(v)
return v


def pydantic_resolve_key(field):
    """Build a reusable pydantic pre-validator for *field*.

    The returned validator resolves dotted import-path strings assigned to
    the field into the objects they name (see ``_pydantic_resolver``).
    """
    make_validator = pydantic.validator(field, pre=True, allow_reuse=True)
    return make_validator(_pydantic_resolver)
12 changes: 9 additions & 3 deletions metadata-ingestion/src/datahub/ingestion/api/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import importlib
import inspect
from typing import Dict, Generic, Type, TypeVar, Union
from typing import Any, Dict, Generic, Type, TypeVar, Union

import entrypoints
import typing_inspect
Expand All @@ -11,6 +11,13 @@
T = TypeVar("T")


def import_key(key: str) -> Any:
    """Dynamically import and return the object named by a dotted path.

    *key* must look like ``"some.module.attribute"``: everything before the
    final dot is imported as a module and the last segment is looked up on
    it with ``getattr``.

    Raises:
        ValueError: if *key* contains no dot (no attribute to resolve).
        ModuleNotFoundError: if the module portion cannot be imported.
        AttributeError: if the imported module lacks the named attribute.
    """
    # Raise explicitly instead of using `assert`, which is silently
    # stripped when Python runs with the -O flag.
    if "." not in key:
        raise ValueError(f"import key must contain a '.': {key!r}")
    module_name, item_name = key.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), item_name)


class Registry(Generic[T]):
    """Registry mapping string keys to classes of type ``T``."""

    def __init__(self):
        # Keyed by registry name. Per the value type, an entry may hold
        # either the registered class or an Exception — NOTE(review):
        # presumably a captured load/registration failure surfaced later
        # on lookup; confirm against the rest of the class (not fully
        # visible here).
        self._mapping: Dict[str, Union[Type[T], Exception]] = {}
Expand Down Expand Up @@ -68,8 +75,7 @@ def get(self, key: str) -> Type[T]:
if key.find(".") >= 0:
# If the key contains a dot, we treat it as a import path and attempt
# to load it dynamically.
module_name, class_name = key.rsplit(".", 1)
MyClass = getattr(importlib.import_module(module_name), class_name)
MyClass = import_key(key)
self._check_cls(MyClass)
return MyClass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datahub.emitter.mce_builder as builder
from datahub.configuration.common import ConfigModel
from datahub.configuration.import_resolver import pydantic_resolve_key
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.transform import Transformer
from datahub.metadata.schema_classes import (
Expand All @@ -22,6 +23,8 @@ class AddDatasetOwnershipConfig(ConfigModel):
]
default_actor: str = builder.make_user_urn("etl")

_resolve_owner_fn = pydantic_resolve_key("get_owners_to_add")


class AddDatasetOwnership(Transformer):
"""Transformer that adds owners to datasets according to a callback function."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datahub.emitter.mce_builder as builder
from datahub.configuration.common import ConfigModel
from datahub.configuration.import_resolver import pydantic_resolve_key
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.transform import Transformer
from datahub.metadata.schema_classes import (
Expand All @@ -20,6 +21,8 @@ class AddDatasetTagsConfig(ConfigModel):
Callable[[DatasetSnapshotClass], List[TagAssociationClass]],
]

_resolve_tag_fn = pydantic_resolve_key("get_tags_to_add")


class AddDatasetTags(Transformer):
"""Transformer that adds tags to datasets according to a callback function."""
Expand Down
24 changes: 23 additions & 1 deletion metadata-ingestion/tests/unit/test_transform_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from datahub.ingestion.transformer.add_dataset_ownership import (
SimpleAddDatasetOwnership,
)
from datahub.ingestion.transformer.add_dataset_tags import SimpleAddDatasetTags
from datahub.ingestion.transformer.add_dataset_tags import (
AddDatasetTags,
SimpleAddDatasetTags,
)


def make_generic_dataset():
Expand Down Expand Up @@ -120,3 +123,22 @@ def test_simple_dataset_tags_transformation(mock_time):
assert tags_aspect
assert len(tags_aspect.tags) == 2
assert tags_aspect.tags[0].tag == builder.make_tag_urn("NeedsDocumentation")


def dummy_tag_resolver_method(dataset_snapshot):
    """No-op tag resolver for the dynamic-import test: contributes no tags."""
    no_tags = []
    return no_tags


def test_import_resolver():
    """A dotted-path string in config must resolve to the callable it names.

    Configures AddDatasetTags with an import-path string instead of a
    callable and checks that the transformer still produces output.
    """
    transformer = AddDatasetTags.create(
        {
            "get_tags_to_add": "tests.unit.test_transform_dataset.dummy_tag_resolver_method"
        },
        PipelineContext(run_id="test-tags"),
    )
    # Name the loop variable `record` — the original shadowed the builtin
    # `input`; also build the (single-element) envelope list up front.
    envelopes = [
        RecordEnvelope(record, metadata={}) for record in [make_generic_dataset()]
    ]
    output = list(transformer.transform(envelopes))
    assert output

0 comments on commit 220dfe7

Please sign in to comment.