Skip to content

Commit

Permalink
Addl artf testing (#2256)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Mar 12, 2024
1 parent 7144ae9 commit 61a3e95
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 30 deletions.
2 changes: 1 addition & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _compute_array_job_index() -> int:
def _outputs_interface(self) -> Dict[Any, Variable]:
"""
We override this method from PythonTask because the dispatch_execute method uses this
interface to construct outputs. Each instance of an container_array task will however produce outputs
interface to construct outputs. Each instance of a container_array task will however produce outputs
according to the underlying run_task interface and the array plugin handler will actually create a collection
from these individual outputs as the final output value.
"""
Expand Down
9 changes: 3 additions & 6 deletions flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(self, a: Artifact):
self.partitions: Optional[Partitions] = None
self.time_partition: Optional[TimePartition] = None

# todo: add time partition arg hint
def __call__(self, *args, **kwargs):
    """
    Make the object callable: every positional and keyword argument is forwarded to bind_partitions
    and whatever bind_partitions returns is handed back to the caller.
    """
    bound = self.bind_partitions(*args, **kwargs)
    return bound

Expand Down Expand Up @@ -80,12 +79,12 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification:
del kwargs[TIME_PARTITION_KWARG]
else:
# If the user has not set a time partition but the artifact is time partitioned, fall back to a dynamic binding.
if self.artifact.time_partitioned:
if self.artifact.time_partitioned and self.time_partition is None:
logger.debug(f"Time partition not bound for {self.artifact.name}, setting to dynamic binding.")
self.time_partition = TimePartition(value=DYNAMIC_INPUT_BINDING)

if len(kwargs) > 0 and (self.artifact.partition_keys and len(self.artifact.partition_keys) > 0):
p = Partitions(None)
if self.artifact.partition_keys and len(self.artifact.partition_keys) > 0:
p = self.partitions or Partitions(None)
# k is the partition key, v should be static, or an input to the task or workflow
for k, v in kwargs.items():
if not self.artifact.partition_keys or k not in self.artifact.partition_keys:
Expand All @@ -103,8 +102,6 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification:
p.partitions[k] = Partition(value=DYNAMIC_INPUT_BINDING, name=k)
# Given the context, shouldn't need to set further reference_artifacts.
self.partitions = p
else:
logger.debug(f"No remaining partition keys for {self.artifact.name}")

return self

Expand Down
8 changes: 5 additions & 3 deletions flytekit/core/artifact_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Dict, Optional

from flyteidl.core.artifact_id_pb2 import LabelValue, Partitions, TimePartition
from flyteidl.core.artifact_id_pb2 import Granularity, LabelValue, Partitions, TimePartition
from google.protobuf.timestamp_pb2 import Timestamp


Expand All @@ -12,11 +12,13 @@ def idl_partitions_from_dict(p: Optional[Dict[str, str]] = None) -> Optional[Par
return None


def idl_time_partition_from_datetime(
    tp: Optional[datetime] = None, time_partition_granularity: Optional[Granularity] = None
) -> Optional[TimePartition]:
    """
    Wrap a datetime in an IDL TimePartition, attaching the given granularity.

    :param tp: The datetime to convert. When it is falsy (e.g. None) no partition is produced.
    :param time_partition_granularity: Passed through unchanged as the granularity of the resulting TimePartition.
    :return: A TimePartition whose value is a LabelValue holding the timestamp, or None when tp was not given.
    """
    if not tp:
        return None

    stamp = Timestamp()
    stamp.FromDatetime(tp)
    return TimePartition(value=LabelValue(time_value=stamp), granularity=time_partition_granularity)
4 changes: 3 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,9 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
if om.dynamic_partitions or om.time_partition:
a = art_id.ArtifactID(
partitions=idl_partitions_from_dict(om.dynamic_partitions),
time_partition=idl_time_partition_from_datetime(om.time_partition),
time_partition=idl_time_partition_from_datetime(
om.time_partition, om.artifact.time_partition_granularity
),
)
s = a.SerializeToString()
encoded = b64encode(s).decode("utf-8")
Expand Down
152 changes: 133 additions & 19 deletions tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from typing_extensions import Annotated, get_args

from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.artifact import Artifact, Inputs
from flytekit.core.context_manager import FlyteContext, FlyteContextManager, OutputMetadataTracker
from flytekit.core.array_node_map_task import map_task
from flytekit.core.artifact import Artifact, Inputs, TimePartition
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, OutputMetadataTracker
from flytekit.core.interface import detect_artifact
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.task import task
Expand All @@ -25,7 +26,6 @@
if "pandas" not in sys.modules:
pytest.skip(reason="Requires pandas", allow_module_level=True)


default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
project="project",
Expand Down Expand Up @@ -81,6 +81,37 @@ def t1(
assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == ""


def test_basic_multiple_call():
    """
    Bind the "b" key, the time partition and the "a" key through three chained calls on the same artifact
    and check that all three bindings show up on the serialized task output.
    """
    import pandas as pd

    artifact = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True)

    @task
    def t1(
        b_value: str, dt: datetime.datetime
    ) -> Annotated[pd.DataFrame, artifact(b=Inputs.b_value)(time_partition=Inputs.dt)(a="manual")]:
        return pd.DataFrame({"a": [1, 2, 3], "b": [b_value] * 3})

    assert artifact.time_partition.granularity == Granularity.DAY

    spec = get_serializable(OrderedDict(), serialization_settings, t1)
    partial_id = spec.template.interface.outputs["o0"].artifact_partial_id
    partitions = partial_id.partitions.value
    assert len(partitions) == 2

    # Time partition: bound to the "dt" input.
    assert partial_id.time_partition is not None
    assert partial_id.time_partition.granularity == art_id.Granularity.DAY
    assert partial_id.time_partition.value.input_binding.var == "dt"

    # "b" is bound to a task input, "a" was given as a static value.
    assert partitions["b"].HasField("input_binding")
    assert partitions["b"].input_binding.var == "b_value"
    assert partitions["a"].HasField("static_value")
    assert partitions["a"].static_value == "manual"

    assert partial_id.version == ""
    assert partial_id.artifact_key.name == "my_data"
    assert partial_id.artifact_key.project == ""


def test_args_getting():
a1 = Artifact(name="argstst")
a1_called = a1()
Expand Down Expand Up @@ -162,7 +193,9 @@ def test_basic_dynamic():
omt = OutputMetadataTracker()
ctx = ctx.with_output_metadata_tracker(omt).build()

a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True)
a1_t_ab = Artifact(
name="my_data", partition_keys=["a", "b"], time_partitioned=True, time_partition_granularity=Granularity.MONTH
)

@task
def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value)]:
Expand Down Expand Up @@ -192,6 +225,41 @@ def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b
proto_timestamp = Timestamp()
proto_timestamp.FromDatetime(d)
assert artifact_id.time_partition.value.time_value == proto_timestamp
assert artifact_id.time_partition.granularity == Granularity.MONTH


def test_basic_dynamic_only_time():
    """
    Ensure the metadata tracking component works when the user only binds a time at run time.
    """
    import pandas as pd

    # Without the output metadata tracker, the part that keeps track of dynamic partitions doesn't kick in.
    ctx = FlyteContextManager.current_context().with_output_metadata_tracker(OutputMetadataTracker()).build()

    time_only = Artifact(name="my_data", time_partitioned=True)

    @task
    def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, time_only]:
        frame = pd.DataFrame({"a": [1, 2, 3], "b": [b_value] * 3})
        return time_only.create_from(frame, time_partition=dt)

    spec = get_serializable(OrderedDict(), serialization_settings, t1)
    partial_id = spec.template.interface.outputs["o0"].artifact_partial_id
    assert not partial_id.partitions.value
    assert partial_id.time_partition is not None

    when = datetime.datetime(2021, 1, 1, 0, 0)
    inputs = TypeEngine.dict_to_literal_map(ctx, {"b_value": "my b value", "dt": when})
    outputs = t1.dispatch_execute(ctx, inputs)

    # The run-time binding travels base64-encoded in the "_uap" metadata entry of the output literal.
    encoded = outputs.literals["o0"].metadata["_uap"]
    decoded = art_id.ArtifactID()
    decoded.ParseFromString(b64decode(encoded.encode("utf-8")))
    assert not decoded.partitions.value

    expected = Timestamp()
    expected.FromDatetime(when)
    assert decoded.time_partition.value.time_value == expected


def test_dynamic_with_extras():
Expand Down Expand Up @@ -239,20 +307,6 @@ def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b
assert o0.metadata["p2_metao0"] == "this is more extra information"


def test_basic_no_call():
import pandas as pd

a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True)

# raise an error because the user hasn't () the artifact
with pytest.raises(ValueError):

@task
def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab]:
df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})
return df


def test_basic_option_a3():
import pandas as pd

Expand Down Expand Up @@ -437,6 +491,9 @@ def t2(
df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})
return df

with pytest.raises(ValueError):
Artifact(partition_keys=["a", "b"], time_partitioned=True)


def test_dynamic_input_binding():
a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True)
Expand All @@ -461,11 +518,68 @@ def test_tp_granularity():
assert a1_t_b.time_partition.granularity == Granularity.MONTH

@task
def t1(b_value: str, dt: datetime.datetime) -> Annotated[int, a1_t_b(time_partition=Inputs.dt, b=Inputs.b_value)]:
def t1(b_value: str, dt: datetime.datetime) -> Annotated[int, a1_t_b(b=Inputs.b_value)(time_partition=Inputs.dt)]:
return 5

entities = OrderedDict()
spec = get_serializable(entities, serialization_settings, t1)
assert (
spec.template.interface.outputs["o0"].artifact_partial_id.time_partition.granularity == art_id.Granularity.MONTH
)


def test_map_doesnt_add_any_metadata():
    """
    The base task only looks for items in the metadata tracker at the top level. This test pins that
    state for now (it may be revisited): outputs of a mapped task carry no artifact id and no metadata.
    """
    import pandas as pd

    base_ctx = FlyteContextManager.current_context()
    local_wf_state = base_ctx.execution_state.with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION)
    # Without the output metadata tracker, the part that keeps track of dynamic partitions doesn't kick in.
    builder = base_ctx.with_output_metadata_tracker(OutputMetadataTracker())
    ctx = builder.with_execution_state(local_wf_state).build()

    keyed = Artifact(name="my_data", partition_keys=["b"])

    @task
    def t1(b_value: str) -> Annotated[pd.DataFrame, keyed]:
        frame = pd.DataFrame({"a": [1, 2, 3], "b": [b_value] * 3})
        return keyed.create_from(frame, b="dynamic!")

    mapped = map_task(t1)
    mapped_spec = get_serializable(OrderedDict(), serialization_settings, mapped)
    assert not mapped_spec.template.interface.outputs["o0"].artifact_partial_id

    inputs = TypeEngine.dict_to_literal_map(
        ctx, {"b_value": ["my b value 1", "my b value 2"]}, type_hints={"b_value": typing.List[str]}
    )
    results = mapped.dispatch_execute(ctx, inputs).literals["o0"].collection.literals
    assert not results[0].metadata
    assert not results[1].metadata


def test_tp_math():
    """
    Adding or subtracting a timedelta on a TimePartition gives back a distinct object that records
    the operator and the offset and keeps the original granularity.
    """
    artifact = Artifact(name="test artifact", time_partitioned=True)
    stamp = Timestamp()
    stamp.FromDatetime(datetime.datetime(2063, 4, 5, 0, 0))
    one_day = datetime.timedelta(days=1)

    def make_partition():
        # A fresh hourly partition tied to the artifact, built separately for each direction.
        part = TimePartition(value=art_id.LabelValue(time_value=stamp), granularity=Granularity.HOUR)
        part.reference_artifact = artifact
        return part

    base = make_partition()
    later = base + one_day
    assert later.op == art_id.Operator.PLUS
    assert later.other == datetime.timedelta(days=1)
    assert later.granularity == Granularity.HOUR
    assert later is not base

    base = make_partition()
    earlier = base - one_day
    assert earlier.op == art_id.Operator.MINUS
    assert earlier.other == datetime.timedelta(days=1)
    assert earlier.granularity == Granularity.HOUR
    assert earlier is not base

0 comments on commit 61a3e95

Please sign in to comment.