From 61a3e9554df52a716b810370b9486d96780af1e3 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 12 Mar 2024 12:07:24 -0700 Subject: [PATCH] Addl artf testing (#2256) Signed-off-by: Yee Hing Tong --- flytekit/core/array_node_map_task.py | 2 +- flytekit/core/artifact.py | 9 +- flytekit/core/artifact_utils.py | 8 +- flytekit/core/base_task.py | 4 +- tests/flytekit/unit/core/test_artifacts.py | 152 ++++++++++++++++++--- 5 files changed, 145 insertions(+), 30 deletions(-) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 5d10cadd45..7f4a837644 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -242,7 +242,7 @@ def _compute_array_job_index() -> int: def _outputs_interface(self) -> Dict[Any, Variable]: """ We override this method from PythonTask because the dispatch_execute method uses this - interface to construct outputs. Each instance of an container_array task will however produce outputs + interface to construct outputs. Each instance of a container_array task will however produce outputs according to the underlying run_task interface and the array plugin handler will actually create a collection from these individual outputs as the final output value. """ diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index 5228c2b701..742270a9ae 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -48,7 +48,6 @@ def __init__(self, a: Artifact): self.partitions: Optional[Partitions] = None self.time_partition: Optional[TimePartition] = None - # todo: add time partition arg hint def __call__(self, *args, **kwargs): return self.bind_partitions(*args, **kwargs) @@ -80,12 +79,12 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification: del kwargs[TIME_PARTITION_KWARG] else: # If user has not set time partition, - if self.artifact.time_partitioned: + if self.artifact.time_partitioned and self.time_partition is None: logger.debug(f"Time partition not bound for {self.artifact.name}, setting to dynamic binding.") self.time_partition = TimePartition(value=DYNAMIC_INPUT_BINDING) - if len(kwargs) > 0 and (self.artifact.partition_keys and len(self.artifact.partition_keys) > 0): - p = Partitions(None) + if self.artifact.partition_keys and len(self.artifact.partition_keys) > 0: + p = self.partitions or Partitions(None) # k is the partition key, v should be static, or an input to the task or workflow for k, v in kwargs.items(): if not self.artifact.partition_keys or k not in self.artifact.partition_keys: @@ -103,8 +102,6 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification: p.partitions[k] = Partition(value=DYNAMIC_INPUT_BINDING, name=k) # Given the context, shouldn't need to set further reference_artifacts. self.partitions = p - else: - logger.debug(f"No remaining partition keys for {self.artifact.name}") return self diff --git a/flytekit/core/artifact_utils.py b/flytekit/core/artifact_utils.py index e8edbb19d9..e8ba247ecf 100644 --- a/flytekit/core/artifact_utils.py +++ b/flytekit/core/artifact_utils.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Dict, Optional -from flyteidl.core.artifact_id_pb2 import LabelValue, Partitions, TimePartition +from flyteidl.core.artifact_id_pb2 import Granularity, LabelValue, Partitions, TimePartition from google.protobuf.timestamp_pb2 import Timestamp @@ -12,11 +12,13 @@ def idl_partitions_from_dict(p: Optional[Dict[str, str]] = None) -> Optional[Par return None -def idl_time_partition_from_datetime(tp: Optional[datetime] = None) -> Optional[TimePartition]: +def idl_time_partition_from_datetime( + tp: Optional[datetime] = None, time_partition_granularity: Optional[Granularity] = None +) -> Optional[TimePartition]: if tp: t = Timestamp() t.FromDatetime(tp) lv = LabelValue(time_value=t) - return TimePartition(value=lv) + return TimePartition(value=lv, granularity=time_partition_granularity) return None diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 1842a9957f..04d1d34c02 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -601,7 +601,9 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte if om.dynamic_partitions or om.time_partition: a = art_id.ArtifactID( partitions=idl_partitions_from_dict(om.dynamic_partitions), - time_partition=idl_time_partition_from_datetime(om.time_partition), + time_partition=idl_time_partition_from_datetime( + om.time_partition, om.artifact.time_partition_granularity + ), ) s = a.SerializeToString() encoded = b64encode(s).decode("utf-8") diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py index 72280b6ffc..ea3734f8aa 100644 --- a/tests/flytekit/unit/core/test_artifacts.py +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -12,8 +12,9 @@ from typing_extensions import Annotated, get_args from flytekit.configuration import Image, ImageConfig, SerializationSettings -from flytekit.core.artifact import Artifact, Inputs -from flytekit.core.context_manager import FlyteContext, FlyteContextManager, OutputMetadataTracker +from flytekit.core.array_node_map_task import map_task +from flytekit.core.artifact import Artifact, Inputs, TimePartition +from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, OutputMetadataTracker from flytekit.core.interface import detect_artifact from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task @@ -25,7 +26,6 @@ if "pandas" not in sys.modules: pytest.skip(reason="Requires pandas", allow_module_level=True) - default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = SerializationSettings( project="project", @@ -81,6 +81,37 @@ def t1( assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" +def test_basic_multiple_call(): + import pandas as pd + + a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) + + @task + def t1( + b_value: str, dt: datetime.datetime + ) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value)(time_partition=Inputs.dt)(a="manual")]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + assert a1_t_ab.time_partition.granularity == Granularity.DAY + entities = OrderedDict() + t1_s = get_serializable(entities, serialization_settings, t1) + assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2 + p = t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition is not None + assert ( + t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition.granularity == art_id.Granularity.DAY + ) + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition.value.input_binding.var == "dt" + assert p["b"].HasField("input_binding") + assert p["b"].input_binding.var == "b_value" + assert p["a"].HasField("static_value") + assert p["a"].static_value == "manual" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.version == "" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" + + def test_args_getting(): a1 = Artifact(name="argstst") a1_called = a1() @@ -162,7 +193,9 @@ def test_basic_dynamic(): omt = OutputMetadataTracker() ctx = ctx.with_output_metadata_tracker(omt).build() - a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) + a1_t_ab = Artifact( + name="my_data", partition_keys=["a", "b"], time_partitioned=True, time_partition_granularity=Granularity.MONTH + ) @task def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value)]: @@ -192,6 +225,41 @@ def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b proto_timestamp = Timestamp() proto_timestamp.FromDatetime(d) assert artifact_id.time_partition.value.time_value == proto_timestamp + assert artifact_id.time_partition.granularity == Granularity.MONTH + + +def test_basic_dynamic_only_time(): + # This test is to ensure the metadata tracking component works if the user only binds a time at run time. + import pandas as pd + + ctx = FlyteContextManager.current_context() + # without this omt, the part that keeps track of dynamic partitions doesn't kick in. + omt = OutputMetadataTracker() + ctx = ctx.with_output_metadata_tracker(omt).build() + + a1_t = Artifact(name="my_data", time_partitioned=True) + + @task + def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return a1_t.create_from(df, time_partition=dt) + + entities = OrderedDict() + t1_s = get_serializable(entities, serialization_settings, t1) + assert not t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition is not None + + d = datetime.datetime(2021, 1, 1, 0, 0) + lm = TypeEngine.dict_to_literal_map(ctx, {"b_value": "my b value", "dt": d}) + lm_outputs = t1.dispatch_execute(ctx, lm) + dyn_partition_encoded = lm_outputs.literals["o0"].metadata["_uap"] + artifact_id = art_id.ArtifactID() + artifact_id.ParseFromString(b64decode(dyn_partition_encoded.encode("utf-8"))) + assert not artifact_id.partitions.value + + proto_timestamp = Timestamp() + proto_timestamp.FromDatetime(d) + assert artifact_id.time_partition.value.time_value == proto_timestamp def test_dynamic_with_extras(): @@ -239,20 +307,6 @@ def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b assert o0.metadata["p2_metao0"] == "this is more extra information" -def test_basic_no_call(): - import pandas as pd - - a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) - - # raise an error because the user hasn't () the artifact - with pytest.raises(ValueError): - - @task - def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab]: - df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) - return df - - def test_basic_option_a3(): import pandas as pd @@ -437,6 +491,9 @@ def t2( df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) return df + with pytest.raises(ValueError): + Artifact(partition_keys=["a", "b"], time_partitioned=True) + def test_dynamic_input_binding(): a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) @@ -461,7 +518,7 @@ def test_tp_granularity(): assert a1_t_b.time_partition.granularity == Granularity.MONTH @task - def t1(b_value: str, dt: datetime.datetime) -> Annotated[int, a1_t_b(time_partition=Inputs.dt, b=Inputs.b_value)]: + def t1(b_value: str, dt: datetime.datetime) -> Annotated[int, a1_t_b(b=Inputs.b_value)(time_partition=Inputs.dt)]: return 5 entities = OrderedDict() @@ -469,3 +526,60 @@ def t1(b_value: str, dt: datetime.datetime) -> Annotated[int, a1_t_b(time_partit assert ( spec.template.interface.outputs["o0"].artifact_partial_id.time_partition.granularity == art_id.Granularity.MONTH ) + + +def test_map_doesnt_add_any_metadata(): + # The base task only looks for items in the metadata tracker at the top level. This test is here to maintain + # that state for now, though we may want to revisit this. + import pandas as pd + + ctx = FlyteContextManager.current_context() + # without this omt, the part that keeps track of dynamic partitions doesn't kick in. + omt = OutputMetadataTracker() + ctx = ( + ctx.with_output_metadata_tracker(omt) + .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION)) + .build() + ) + + a1_b = Artifact(name="my_data", partition_keys=["b"]) + + @task + def t1(b_value: str) -> Annotated[pd.DataFrame, a1_b]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return a1_b.create_from(df, b="dynamic!") + + mt1 = map_task(t1) + entities = OrderedDict() + mt1_s = get_serializable(entities, serialization_settings, mt1) + o0 = mt1_s.template.interface.outputs["o0"] + assert not o0.artifact_partial_id + lm = TypeEngine.dict_to_literal_map( + ctx, {"b_value": ["my b value 1", "my b value 2"]}, type_hints={"b_value": typing.List[str]} + ) + lm_outputs = mt1.dispatch_execute(ctx, lm) + coll = lm_outputs.literals["o0"].collection.literals + assert not coll[0].metadata + assert not coll[1].metadata + + +def test_tp_math(): + a = Artifact(name="test artifact", time_partitioned=True) + d = datetime.datetime(2063, 4, 5, 0, 0) + pt = Timestamp() + pt.FromDatetime(d) + tp = TimePartition(value=art_id.LabelValue(time_value=pt), granularity=Granularity.HOUR) + tp.reference_artifact = a + tp2 = tp + datetime.timedelta(days=1) + assert tp2.op == art_id.Operator.PLUS + assert tp2.other == datetime.timedelta(days=1) + assert tp2.granularity == Granularity.HOUR + assert tp2 is not tp + + tp = TimePartition(value=art_id.LabelValue(time_value=pt), granularity=Granularity.HOUR) + tp.reference_artifact = a + tp2 = tp - datetime.timedelta(days=1) + assert tp2.op == art_id.Operator.MINUS + assert tp2.other == datetime.timedelta(days=1) + assert tp2.granularity == Granularity.HOUR + assert tp2 is not tp