Addl artf testing #2256

Merged
merged 10 commits on Mar 12, 2024
Changes from 9 commits
2 changes: 1 addition & 1 deletion flytekit/core/array_node_map_task.py
@@ -242,7 +242,7 @@ def _compute_array_job_index() -> int:
def _outputs_interface(self) -> Dict[Any, Variable]:
"""
We override this method from PythonTask because the dispatch_execute method uses this
interface to construct outputs. Each instance of an container_array task will however produce outputs
interface to construct outputs. Each instance of a container_array task will however produce outputs
according to the underlying run_task interface and the array plugin handler will actually create a collection
from these individual outputs as the final output value.
"""
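The docstring above describes how each mapped instance produces outputs according to the underlying task's interface, with the array plugin handler collecting them into the final list-typed output. A minimal sketch of that behavior from the caller's side, assuming the map_task import path used in the tests below (flytekit.core.array_node_map_task); this is an illustrative example, not part of the diff:

from typing import List

from flytekit import task, workflow
from flytekit.core.array_node_map_task import map_task


@task
def double(x: int) -> int:
    # Each mapped instance runs this int -> int interface.
    return x * 2


@workflow
def wf(xs: List[int]) -> List[int]:
    # The array handler gathers the per-instance ints into one List[int] output.
    return map_task(double)(x=xs)
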
9 changes: 3 additions & 6 deletions flytekit/core/artifact.py
@@ -48,7 +48,6 @@
self.partitions: Optional[Partitions] = None
self.time_partition: Optional[TimePartition] = None

# todo: add time partition arg hint
def __call__(self, *args, **kwargs):
return self.bind_partitions(*args, **kwargs)

@@ -80,12 +79,12 @@
del kwargs[TIME_PARTITION_KWARG]
else:
# If user has not set time partition,
if self.artifact.time_partitioned:
if self.artifact.time_partitioned and self.time_partition is None:
logger.debug(f"Time partition not bound for {self.artifact.name}, setting to dynamic binding.")
self.time_partition = TimePartition(value=DYNAMIC_INPUT_BINDING)

if len(kwargs) > 0 and (self.artifact.partition_keys and len(self.artifact.partition_keys) > 0):
p = Partitions(None)
if self.artifact.partition_keys and len(self.artifact.partition_keys) > 0:
p = self.partitions or Partitions(None)

[Codecov / codecov/patch check warning on flytekit/core/artifact.py#L87: added line was not covered by tests]
# k is the partition key, v should be static, or an input to the task or workflow
for k, v in kwargs.items():
if not self.artifact.partition_keys or k not in self.artifact.partition_keys:
@@ -103,8 +102,6 @@
p.partitions[k] = Partition(value=DYNAMIC_INPUT_BINDING, name=k)
# Given the context, shouldn't need to set further reference_artifacts.
self.partitions = p
else:
logger.debug(f"No remaining partition keys for {self.artifact.name}")

return self

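For context on the binding changes above: the time partition now falls back to a dynamic binding only when it has not already been bound, and partition bindings accumulate across calls instead of being rebuilt on each call. A minimal sketch of the multiple-call binding this enables, mirroring the new test_basic_multiple_call test added below (assumes pandas is installed); illustrative only, not part of the diff:

import datetime

import pandas as pd
from typing_extensions import Annotated

from flytekit.core.artifact import Artifact, Inputs
from flytekit.core.task import task

my_data = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True)


@task
def t1(
    b_value: str, dt: datetime.datetime
) -> Annotated[pd.DataFrame, my_data(b=Inputs.b_value)(time_partition=Inputs.dt)(a="manual")]:
    # "b" binds to the b_value input, the time partition binds to dt, and "a" is static.
    return pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})
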
7 changes: 0 additions & 7 deletions flytekit/remote/remote_fs.py
@@ -73,13 +73,6 @@ def _upload_chunk(self, final=False):
self.buffer.seek(0)
data = self.buffer.read()

# h = hashlib.md5()
# h.update(data)
# md5 = h.digest()
# l = len(data)
#
# headers = {"Content-Length": str(l), "Content-MD5": md5}

try:
res = self._remote.client.get_upload_signed_url(
self._remote.default_project,
147 changes: 129 additions & 18 deletions tests/flytekit/unit/core/test_artifacts.py
@@ -12,8 +12,9 @@
from typing_extensions import Annotated, get_args

from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.artifact import Artifact, Inputs
from flytekit.core.context_manager import FlyteContext, FlyteContextManager, OutputMetadataTracker
from flytekit.core.array_node_map_task import map_task
from flytekit.core.artifact import Artifact, Inputs, TimePartition
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, OutputMetadataTracker
from flytekit.core.interface import detect_artifact
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.task import task
@@ -25,7 +26,6 @@
if "pandas" not in sys.modules:
pytest.skip(reason="Requires pandas", allow_module_level=True)


default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
project="project",
@@ -81,6 +81,37 @@ def t1(
assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == ""


def test_basic_multiple_call():
import pandas as pd

a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True)

@task
def t1(
b_value: str, dt: datetime.datetime
) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value)(time_partition=Inputs.dt)(a="manual")]:
df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})
return df

assert a1_t_ab.time_partition.granularity == Granularity.DAY
entities = OrderedDict()
t1_s = get_serializable(entities, serialization_settings, t1)
assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2
p = t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value
assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition is not None
assert (
t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition.granularity == art_id.Granularity.DAY
)
assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition.value.input_binding.var == "dt"
assert p["b"].HasField("input_binding")
assert p["b"].input_binding.var == "b_value"
assert p["a"].HasField("static_value")
assert p["a"].static_value == "manual"
assert t1_s.template.interface.outputs["o0"].artifact_partial_id.version == ""
assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data"
assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == ""


def test_args_getting():
a1 = Artifact(name="argstst")
a1_called = a1()
@@ -194,6 +225,40 @@ def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b
assert artifact_id.time_partition.value.time_value == proto_timestamp


def test_basic_dynamic_only_time():
# This test is to ensure the metadata tracking component works if the user only binds a time at run time.
import pandas as pd

ctx = FlyteContextManager.current_context()
# without this omt, the part that keeps track of dynamic partitions doesn't kick in.
omt = OutputMetadataTracker()
ctx = ctx.with_output_metadata_tracker(omt).build()

a1_t = Artifact(name="my_data", time_partitioned=True)

@task
def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t]:
df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})
return a1_t.create_from(df, time_partition=dt)

entities = OrderedDict()
t1_s = get_serializable(entities, serialization_settings, t1)
assert not t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value
assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition is not None

d = datetime.datetime(2021, 1, 1, 0, 0)
lm = TypeEngine.dict_to_literal_map(ctx, {"b_value": "my b value", "dt": d})
lm_outputs = t1.dispatch_execute(ctx, lm)
dyn_partition_encoded = lm_outputs.literals["o0"].metadata["_uap"]
artifact_id = art_id.ArtifactID()
artifact_id.ParseFromString(b64decode(dyn_partition_encoded.encode("utf-8")))
assert not artifact_id.partitions.value

proto_timestamp = Timestamp()
proto_timestamp.FromDatetime(d)
assert artifact_id.time_partition.value.time_value == proto_timestamp


def test_dynamic_with_extras():
import pandas as pd

@@ -239,20 +304,6 @@ def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b
assert o0.metadata["p2_metao0"] == "this is more extra information"


def test_basic_no_call():
import pandas as pd

a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True)

# raise an error because the user hasn't () the artifact
with pytest.raises(ValueError):

@task
def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab]:
df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})
return df


def test_basic_option_a3():
import pandas as pd

@@ -437,6 +488,9 @@ def t2(
df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})
return df

with pytest.raises(ValueError):
Artifact(partition_keys=["a", "b"], time_partitioned=True)


def test_dynamic_input_binding():
a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True)
@@ -461,11 +515,68 @@ def test_tp_granularity():
assert a1_t_b.time_partition.granularity == Granularity.MONTH

@task
def t1(b_value: str, dt: datetime.datetime) -> Annotated[int, a1_t_b(time_partition=Inputs.dt, b=Inputs.b_value)]:
def t1(b_value: str, dt: datetime.datetime) -> Annotated[int, a1_t_b(b=Inputs.b_value)(time_partition=Inputs.dt)]:
return 5

entities = OrderedDict()
spec = get_serializable(entities, serialization_settings, t1)
assert (
spec.template.interface.outputs["o0"].artifact_partial_id.time_partition.granularity == art_id.Granularity.MONTH
)


def test_map_doesnt_add_any_metadata():
# The base task only looks for items in the metadata tracker at the top level. This test is here to maintain
# that state for now, though we may want to revisit this.
import pandas as pd

ctx = FlyteContextManager.current_context()
# without this omt, the part that keeps track of dynamic partitions doesn't kick in.
omt = OutputMetadataTracker()
ctx = (
ctx.with_output_metadata_tracker(omt)
.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION))
.build()
)

a1_b = Artifact(name="my_data", partition_keys=["b"])

@task
def t1(b_value: str) -> Annotated[pd.DataFrame, a1_b]:
df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})
return a1_b.create_from(df, b="dynamic!")

mt1 = map_task(t1)
entities = OrderedDict()
mt1_s = get_serializable(entities, serialization_settings, mt1)
o0 = mt1_s.template.interface.outputs["o0"]
assert not o0.artifact_partial_id
lm = TypeEngine.dict_to_literal_map(
ctx, {"b_value": ["my b value 1", "my b value 2"]}, type_hints={"b_value": typing.List[str]}
)
lm_outputs = mt1.dispatch_execute(ctx, lm)
coll = lm_outputs.literals["o0"].collection.literals
assert not coll[0].metadata
assert not coll[1].metadata


def test_tp_math():
a = Artifact(name="test artifact", time_partitioned=True)
d = datetime.datetime(2063, 4, 5, 0, 0)
pt = Timestamp()
pt.FromDatetime(d)
tp = TimePartition(value=art_id.LabelValue(time_value=pt), granularity=Granularity.HOUR)
tp.reference_artifact = a
tp2 = tp + datetime.timedelta(days=1)
assert tp2.op == art_id.Operator.PLUS
assert tp2.other == datetime.timedelta(days=1)
assert tp2.granularity == Granularity.HOUR
assert tp2 is not tp

tp = TimePartition(value=art_id.LabelValue(time_value=pt), granularity=Granularity.HOUR)
tp.reference_artifact = a
tp2 = tp - datetime.timedelta(days=1)
assert tp2.op == art_id.Operator.MINUS
assert tp2.other == datetime.timedelta(days=1)
assert tp2.granularity == Granularity.HOUR
assert tp2 is not tp