Skip to content

Commit

Permalink
Create AutomationConditionTester
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Jun 4, 2024
1 parent 9d831d9 commit 8aba3a2
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 4 deletions.
1 change: 1 addition & 0 deletions python_modules/dagster/dagster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
from dagster._core.definitions.declarative_automation import (
AssetCondition as AssetCondition,
AutomationCondition as AutomationCondition,
AutomationConditionTester as AutomationConditionTester,
)
from dagster._core.definitions.decorators.asset_check_decorator import (
asset_check as asset_check,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .automation_condition import AutomationCondition as AutomationCondition
from .automation_condition_tester import AutomationConditionTester as AutomationConditionTester
from .legacy import RuleCondition as RuleCondition
from .legacy.asset_condition import AssetCondition as AssetCondition
from .operands import (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import datetime
import logging
from collections import defaultdict
from functools import cached_property
from typing import AbstractSet, Mapping, Optional, Sequence

import mock
import pendulum

from dagster._core.asset_graph_view.asset_graph_view import AssetGraphView
from dagster._core.definitions.asset_daemon_cursor import AssetDaemonCursor
from dagster._core.definitions.asset_key import AssetKey
from dagster._core.definitions.asset_selection import AssetSelection
from dagster._core.definitions.data_time import CachingDataTimeResolver
from dagster._core.definitions.declarative_automation.automation_condition_evaluator import (
AutomationConditionEvaluator,
)
from dagster._core.definitions.declarative_automation.serialized_objects import (
AssetConditionEvaluation,
)
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.events import AssetKeyPartitionKey, AssetMaterialization
from dagster._core.instance import DagsterInstance
from dagster._utils.log import create_console_logger


class AutomationConditionTesterResult:
def __init__(self, requested_asset_partitions: AbstractSet[AssetKeyPartitionKey]):
self._requested_asset_partitions = requested_asset_partitions

@cached_property
def _requested_partitions_by_asset_key(self) -> Mapping[AssetKey, AbstractSet[Optional[str]]]:
mapping = defaultdict(set)
for asset_partition in self._requested_asset_partitions:
mapping[asset_partition.asset_key].add(asset_partition.partition_key)
return mapping

@property
def total_requested(self) -> int:
"""Returns the total number of asset partitions requested during this evaluation."""
return len(self._requested_asset_partitions)

def get_requested_partitions(self, asset_key: AssetKey) -> AbstractSet[Optional[str]]:
"""Returns the specific partition keys requested for the given asset during this evaluation."""
return self._requested_partitions_by_asset_key[asset_key]

def get_num_requested(self, asset_key: AssetKey) -> int:
"""Returns the number of asset partitions requested for the given asset during this evaluation."""
return len(self.get_requested_partitions(asset_key))


class AutomationConditionTester:
def __init__(
self,
defs: Definitions,
asset_selection: AssetSelection = AssetSelection.all(),
current_time: Optional[datetime.datetime] = None,
):
self._defs = defs
self._asset_selection = asset_selection
self._current_time = current_time or pendulum.now("UTC")
self._instance = DagsterInstance.ephemeral()
self._cursor = AssetDaemonCursor.empty()
self._logger = create_console_logger("dagster.automation", logging.DEBUG)

def set_current_time(self, dt: datetime.datetime) -> None:
self._current_time = dt

def add_materializations(
self, asset_key: AssetKey, partitions: Optional[Sequence[str]] = None
) -> None:
with mock.patch("time.time", new=lambda: self._current_time.timestamp()):
for partition in partitions or {None}:
self._instance.report_runless_asset_event(
AssetMaterialization(asset_key=asset_key, partition=partition)
)

def evaluate(self) -> AutomationConditionTesterResult:
"""Evaluates the AutomationConditions of all provided assets."""
asset_graph_view = AssetGraphView.for_test(
defs=self._defs,
instance=self._instance,
effective_dt=self._current_time,
last_event_id=self._instance.event_log_storage.get_maximum_record_id(),
)
asset_graph = self._defs.get_asset_graph()
data_time_resolver = CachingDataTimeResolver(
asset_graph_view.get_inner_queryer_for_back_compat()
)
evaluator = AutomationConditionEvaluator(
asset_graph=asset_graph,
asset_keys=self._asset_selection.resolve(asset_graph),
asset_graph_view=asset_graph_view,
logger=self._logger,
cursor=self._cursor,
data_time_resolver=data_time_resolver,
respect_materialization_data_versions=False,
auto_materialize_run_tags={},
)
results, requested_asset_partitions = evaluator.evaluate()
self._cursor = AssetDaemonCursor(
evaluation_id=0,
last_observe_request_timestamp_by_asset_key={},
previous_evaluation_state=None,
previous_condition_cursors=[result.get_new_cursor() for result in results],
)

for result in results:
self._logger.info(f"Evaluation of {result.asset_key}:")
self._log_evaluation(result.serializable_evaluation)

return AutomationConditionTesterResult(requested_asset_partitions)

def _log_evaluation(self, evaluation: AssetConditionEvaluation, depth: int = 1) -> None:
msg = " " * depth
msg += f"{evaluation.condition_snapshot.description} "
msg += f"({evaluation.true_subset.size} true) "
msg += (
f"({(evaluation.end_timestamp or 0) - (evaluation.start_timestamp or 0):.2f} seconds)"
)
self._logger.info(msg)
for child in evaluation.child_evaluations:
self._log_evaluation(child, depth + 1)
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Optional
from typing import TYPE_CHECKING, Optional

from dagster._annotations import experimental
from dagster._core.definitions.auto_materialize_rule import AutoMaterializeRule
from dagster._serdes.serdes import whitelist_for_serdes
from dagster._utils.security import non_secure_md5_hash_str

from ..automation_condition import AutomationResult
from ..automation_context import AutomationContext
from .asset_condition import AssetCondition

if TYPE_CHECKING:
from ..automation_condition import AutomationResult
from ..automation_context import AutomationContext


@experimental
@whitelist_for_serdes
Expand All @@ -27,7 +29,7 @@ def get_unique_id(self, *, parent_unique_id: Optional[str], index: Optional[str]
def description(self) -> str:
return self.rule.description

def evaluate(self, context: AutomationContext) -> AutomationResult:
def evaluate(self, context: "AutomationContext") -> "AutomationResult":
context.logger.debug(f"Evaluating rule: {self.rule.to_snapshot()}")
# Allow for access to legacy context in legacy rule evaluation
evaluation_result = self.rule.evaluate_for_asset(context)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import datetime

from dagster import (
AssetSelection,
AutomationCondition,
AutomationConditionTester,
Definitions,
HourlyPartitionsDefinition,
StaticPartitionsDefinition,
asset,
)


@asset(
partitions_def=HourlyPartitionsDefinition("2020-01-01-00:00"),
auto_materialize_policy=AutomationCondition.eager().as_auto_materialize_policy(),
)
def hourly() -> None: ...


@asset(
partitions_def=StaticPartitionsDefinition(["a", "b", "c"]),
auto_materialize_policy=AutomationCondition.eager().as_auto_materialize_policy(),
)
def static() -> None: ...


@asset(
auto_materialize_policy=AutomationCondition.eager().as_auto_materialize_policy(),
)
def unpartitioned() -> None: ...


defs = Definitions(assets=[hourly, static, unpartitioned])


def test_current_time_manipulation() -> None:
tester = AutomationConditionTester(
defs=defs,
asset_selection=AssetSelection.assets(hourly),
current_time=datetime.datetime(2020, 2, 2),
)

result = tester.evaluate()
assert result.total_requested == 1
assert result.get_requested_partitions(hourly.key) == {"2020-02-01-23:00"}

tester.set_current_time(datetime.datetime(3005, 5, 5))
result = tester.evaluate()
assert result.total_requested == 1
assert result.get_requested_partitions(hourly.key) == {"3005-05-04-23:00"}

tester.add_materializations(hourly.key, ["3005-05-03-23:00"])
result = tester.evaluate()
assert result.total_requested == 0
assert result.get_requested_partitions(hourly.key) == set()

0 comments on commit 8aba3a2

Please sign in to comment.