Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DA][testing] RFC: Create AutomationConditionTester class #22292

Merged
merged 2 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python_modules/dagster/dagster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
from dagster._core.definitions.declarative_automation import (
AssetCondition as AssetCondition,
AutomationCondition as AutomationCondition,
evaluate_automation_conditions as evaluate_automation_conditions,
)
from dagster._core.definitions.decorators.asset_check_decorator import (
asset_check as asset_check,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .automation_condition import AutomationCondition as AutomationCondition
from .automation_condition_tester import (
evaluate_automation_conditions as evaluate_automation_conditions,
)
from .legacy import RuleCondition as RuleCondition
from .legacy.asset_condition import AssetCondition as AssetCondition
from .operands import (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import datetime
import logging
from collections import defaultdict
from functools import cached_property
from typing import AbstractSet, Mapping, Optional, Sequence, Union

from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView
from dagster._core.definitions.asset_daemon_cursor import AssetDaemonCursor
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.data_time import CachingDataTimeResolver
from dagster._core.definitions.declarative_automation.automation_condition_evaluator import (
AutomationConditionEvaluator,
)
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.events import AssetKeyPartitionKey
from dagster._core.instance import DagsterInstance
from dagster._seven import get_current_datetime_in_utc


class EvaluateAutomationConditionsResult:
def __init__(
self,
requested_asset_partitions: AbstractSet[AssetKeyPartitionKey],
cursor: AssetDaemonCursor,
):
self._requested_asset_partitions = requested_asset_partitions
self.cursor = cursor

@cached_property
def _requested_partitions_by_asset_key(self) -> Mapping[AssetKey, AbstractSet[Optional[str]]]:
mapping = defaultdict(set)
for asset_partition in self._requested_asset_partitions:
mapping[asset_partition.asset_key].add(asset_partition.partition_key)
return mapping

@property
def total_requested(self) -> int:
"""Returns the total number of asset partitions requested during this evaluation."""
return len(self._requested_asset_partitions)

def get_requested_partitions(self, asset_key: AssetKey) -> AbstractSet[Optional[str]]:
"""Returns the specific partition keys requested for the given asset during this evaluation."""
return self._requested_partitions_by_asset_key[asset_key]

def get_num_requested(self, asset_key: AssetKey) -> int:
"""Returns the number of asset partitions requested for the given asset during this evaluation."""
return len(self.get_requested_partitions(asset_key))


def evaluate_automation_conditions(
defs: Union[Definitions, Sequence[AssetsDefinition]],
instance: DagsterInstance,
asset_selection: AssetSelection = AssetSelection.all(),
evaluation_time: Optional[datetime.datetime] = None,
cursor: Optional[AssetDaemonCursor] = None,
) -> EvaluateAutomationConditionsResult:
"""Evaluates the AutomationConditions of the provided assets, returning the results. Intended
for use in unit tests.

Params:
defs (Union[Definitions, Sequence[AssetsDefinitions]]):
The definitions to evaluate the conditions of.
instance (DagsterInstance):
The instance to evaluate against.
asset_selection (AssetSelection):
The selection of assets within defs to evaluate against. Defaults to AssetSelection.all()
evaluation_time (Optional[datetime.datetime]):
The time to use for the evaluation. Defaults to the true current time.
cursor (Optional[AssetDaemonCursor]):
The cursor for the computation. If you are evaluating multiple ticks within a test, this
value should be supplied from the `cursor` property of the returned `result` object.

Examples:
.. code-block:: python

from dagster import DagsterInstance, evaluate_automation_conditions

from my_proj import defs

def test_my_automation_conditions() -> None:

instance = DagsterInstance.ephemeral()

# asset starts off as missing, expect it to be requested
result = evaluate_automation_conditions(defs, instance)
assert result.total_requested == 1

# don't re-request the same asset
result = evaluate_automation_conditions(defs, instance, cursor=cursor)
assert result.total_requested == 0


from dagster import AssetExecutionContext
from dagster_dbt import DbtCliResource, dbt_assets


@dbt_assets(manifest=Path("target", "manifest.json"))
def my_dbt_assets(context: AssetExecutionContext, dbt: DbtCliResource):
yield from dbt.cli(["build"], context=context).stream()
"""
if not isinstance(defs, Definitions):
defs = Definitions(assets=defs)

asset_graph_view = AssetGraphView.for_test(
defs=defs,
instance=instance,
effective_dt=evaluation_time or get_current_datetime_in_utc(),
last_event_id=instance.event_log_storage.get_maximum_record_id(),
)
asset_graph = defs.get_asset_graph()
data_time_resolver = CachingDataTimeResolver(
asset_graph_view.get_inner_queryer_for_back_compat()
)
evaluator = AutomationConditionEvaluator(
asset_graph=asset_graph,
asset_keys=asset_selection.resolve(asset_graph),
asset_graph_view=asset_graph_view,
logger=logging.getLogger("dagster.automation_condition_tester"),
cursor=cursor or AssetDaemonCursor.empty(),
data_time_resolver=data_time_resolver,
respect_materialization_data_versions=False,
auto_materialize_run_tags={},
)
results, requested_asset_partitions = evaluator.evaluate()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely should be a goal to have this return a range representation soon. cc: @smackesey

cursor = AssetDaemonCursor(
evaluation_id=0,
last_observe_request_timestamp_by_asset_key={},
previous_evaluation_state=None,
previous_condition_cursors=[result.get_new_cursor() for result in results],
)

return EvaluateAutomationConditionsResult(
cursor=cursor,
requested_asset_partitions=requested_asset_partitions,
)
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Optional
from typing import TYPE_CHECKING, Optional

from dagster._annotations import experimental
from dagster._core.definitions.auto_materialize_rule import AutoMaterializeRule
from dagster._serdes.serdes import whitelist_for_serdes
from dagster._utils.security import non_secure_md5_hash_str

from ..automation_condition import AutomationResult
from ..automation_context import AutomationContext
from .asset_condition import AssetCondition

if TYPE_CHECKING:
from ..automation_condition import AutomationResult
from ..automation_context import AutomationContext


@experimental
@whitelist_for_serdes
Expand All @@ -27,7 +29,7 @@ def get_unique_id(self, *, parent_unique_id: Optional[str], index: Optional[str]
def description(self) -> str:
return self.rule.description

def evaluate(self, context: AutomationContext) -> AutomationResult:
def evaluate(self, context: "AutomationContext") -> "AutomationResult":
context.logger.debug(f"Evaluating rule: {self.rule.to_snapshot()}")
# Allow for access to legacy context in legacy rule evaluation
evaluation_result = self.rule.evaluate_for_asset(context)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from dagster import (
AssetSelection,
AutomationCondition,
Definitions,
HourlyPartitionsDefinition,
StaticPartitionsDefinition,
asset,
evaluate_automation_conditions,
)
from dagster._core.instance import DagsterInstance


@asset(
partitions_def=HourlyPartitionsDefinition("2020-01-01-00:00"),
auto_materialize_policy=AutomationCondition.eager().as_auto_materialize_policy(),
)
def hourly() -> None: ...


@asset(
partitions_def=StaticPartitionsDefinition(["a", "b", "c"]),
auto_materialize_policy=AutomationCondition.eager().as_auto_materialize_policy(),
)
def static() -> None: ...


@asset(
auto_materialize_policy=AutomationCondition.eager().as_auto_materialize_policy(),
)
def unpartitioned() -> None: ...


defs = Definitions(assets=[hourly, static, unpartitioned])


def test_basic_regular_defs() -> None:
instance = DagsterInstance.ephemeral()

result = evaluate_automation_conditions(
defs=defs,
asset_selection=AssetSelection.assets(unpartitioned),
instance=instance,
)
assert result.total_requested == 1

result = evaluate_automation_conditions(
defs=defs,
asset_selection=AssetSelection.assets(unpartitioned),
instance=instance,
cursor=result.cursor,
)
assert result.total_requested == 0


def test_basic_assets_defs() -> None:
instance = DagsterInstance.ephemeral()

result = evaluate_automation_conditions(defs=[unpartitioned], instance=instance)
assert result.total_requested == 1

result = evaluate_automation_conditions(
defs=[unpartitioned], instance=instance, cursor=result.cursor
)
assert result.total_requested == 0