From 692b78c2fd1065846a8b9ff452c8174f0963db03 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 17 Jan 2024 18:15:46 -0800 Subject: [PATCH] use separate time partition in idl change https://github.com/flyteorg/flyte/pull/4737 Signed-off-by: Yee Hing Tong --- flytekit/core/artifact.py | 150 ++++++++++++--------- flytekit/models/literals.py | 4 +- flytekit/trigger.py | 4 +- setup.py | 2 +- tests/flytekit/unit/core/test_artifacts.py | 102 ++++++++------ tests/flytekit/unit/core/test_triggers.py | 59 ++++---- 6 files changed, 186 insertions(+), 135 deletions(-) diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index 866dd7f4ec..e8d02712a6 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -11,8 +11,7 @@ from flyteidl.artifact import artifacts_pb2 from flyteidl.core import artifact_id_pb2 as art_id from flyteidl.core.identifier_pb2 import TaskExecutionIdentifier, WorkflowExecutionIdentifier -from flyteidl.core.literals_pb2 import Literal -from flyteidl.core.types_pb2 import LiteralType +from google.protobuf.timestamp_pb2 import Timestamp from flytekit.loggers import logger from flytekit.models.literals import Literal @@ -51,12 +50,10 @@ class ArtifactIDSpecification(object): having a pointer to the main artifact. """ - def __init__( - self, a: Artifact, partitions: Optional[Partitions] = None, time_partition: Optional[TimePartition] = None - ): + def __init__(self, a: Artifact): self.artifact = a - self.partitions = partitions - self.time_partition = time_partition + self.partitions: Optional[Partitions] = None + self.time_partition: Optional[TimePartition] = None # todo: add time partition arg hint def __call__(self, *args, **kwargs): @@ -72,7 +69,9 @@ def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification: raise ValueError("Cannot bind time partition to non-time partitioned artifact") p = kwargs[TIME_PARTITION_KWARG] if isinstance(p, datetime.datetime): - self.time_partition = TimePartition(value=art_id.LabelValue(static_value=f"{p}")) + t = Timestamp() + t.FromDatetime(p) + self.time_partition = TimePartition(value=art_id.LabelValue(time_value=t)) elif isinstance(p, art_id.InputBindingData): self.time_partition = TimePartition(value=art_id.LabelValue(input_binding=p)) else: @@ -102,20 +101,30 @@ def to_partial_artifact_id(self) -> art_id.ArtifactID: # This function should only be called by transform_variable_map artifact_id = self.artifact.to_flyte_idl().artifact_id # Use the partitions from this object, but replacement is not allowed by protobuf, so generate new object - p = partitions_to_idl(self.partitions, self.time_partition) + p = partitions_to_idl(self.partitions) + tp = None + if self.artifact.time_partitioned: + if not self.time_partition: + raise ValueError( + f"Artifact {artifact_id.artifact_key} requires a time partition, but it hasn't been bound." + ) + tp = self.time_partition.to_flyte_idl() if self.artifact.partition_keys: required = len(self.artifact.partition_keys) - required += 1 if self.artifact.time_partitioned else 0 + # required += 1 if self.artifact.time_partitioned else 0 fulfilled = len(p.value) if p else 0 if required != fulfilled: raise ValueError( - f"Artifact {artifact_id.artifact_key} requires {required} partitions, but only {fulfilled} are bound." + f"Artifact {artifact_id.artifact_key} requires {required} partitions, but only {fulfilled} are " + f"bound." ) artifact_id = art_id.ArtifactID( artifact_key=artifact_id.artifact_key, partitions=p, - version=artifact_id.version, + time_partition=tp, + version=artifact_id.version, # this should almost never be set since setting it + # hardcodes the query to one version ) return artifact_id @@ -155,9 +164,7 @@ def __init__( tag: Optional[str] = None, ): if not name: - raise ValueError(f"Cannot create query without name") - if partitions and partitions.partitions and TIME_PARTITION in partitions.partitions: - raise ValueError(f"Cannot use 'ds' as a partition name, just use time partition") + raise ValueError("Cannot create query without name") # So normally, if you just do MyData.query(partitions={"region": "{{ inputs.region }}"}), it will just # use the input value to fill in the partition. But if you do @@ -203,11 +210,15 @@ def to_flyte_idl( ) return aq - p = partitions_to_idl(self.partitions, self.time_partition, bindings) + p = partitions_to_idl(self.partitions, bindings) + tp = None + if self.time_partition: + tp = self.time_partition.to_flyte_idl(bindings) i = art_id.ArtifactID( artifact_key=ak, partitions=p, + time_partition=tp, ) aq = art_id.ArtifactQuery( @@ -232,13 +243,15 @@ def __init__( other: Optional[timedelta] = None, ): if isinstance(value, str): - value = art_id.LabelValue(static_value=value) + raise ValueError(f"value to a time partition shouldn't be a str {value}") elif isinstance(value, datetime.datetime): - value = art_id.LabelValue(static_value=f"{value}") + t = Timestamp() + t.FromDatetime(value) + value = art_id.LabelValue(time_value=t) elif isinstance(value, art_id.InputBindingData): value = art_id.LabelValue(input_binding=value) # else should already be a LabelValue or None - self.value = value + self.value: art_id.LabelValue = value self.op = op self.other = other self.reference_artifact: Optional[Artifact] = None @@ -253,15 +266,11 @@ def __sub__(self, other: timedelta) -> TimePartition: tp.reference_artifact = self.reference_artifact return tp - def truncate_to_day(self): - # raise NotImplementedError("Not implemented yet") - return self - - def get_idl_partitions_for_trigger(self, bindings: typing.List[Artifact]) -> art_id.Partitions: + def get_idl_partitions_for_trigger(self, bindings: typing.List[Artifact]) -> art_id.TimePartition: if not self.reference_artifact or (self.reference_artifact and self.reference_artifact not in bindings): # basically if there's no reference artifact, or if the reference artifact isn't # in the list of triggers, then treat it like normal. - return art_id.Partitions(value={TIME_PARTITION: self.value}) + return art_id.TimePartition(value=self.value) elif self.reference_artifact in bindings: idx = bindings.index(self.reference_artifact) transform = None @@ -270,55 +279,35 @@ def get_idl_partitions_for_trigger(self, bindings: typing.List[Artifact]) -> art lv = art_id.LabelValue( triggered_binding=art_id.ArtifactBindingData( index=idx, - partition_key=TIME_PARTITION, + bind_to_time_partition=True, transform=transform, ) ) - return art_id.Partitions(value={TIME_PARTITION: lv}) + return art_id.TimePartition(value=lv) # investigate if this happens, if not, remove. else logger.warning(f"Investigate - time partition in trigger with unhandled reference artifact {self}") - return art_id.Partitions(value={TIME_PARTITION: self.value}) + raise ValueError("Time partition reference artifact not found in ") + # return art_id.Partitions(value={TIME_PARTITION: self.value}) - def to_flyte_idl(self, bindings: Optional[typing.List[Artifact]] = None) -> Optional[art_id.Partitions]: + def to_flyte_idl(self, bindings: Optional[typing.List[Artifact]] = None) -> Optional[art_id.TimePartition]: if bindings and len(bindings) > 0: return self.get_idl_partitions_for_trigger(bindings) if not self.value: # This is only for triggers - the backend needs to know of the existence of a time partition - return art_id.Partitions(value={TIME_PARTITION: art_id.LabelValue(static_value="")}) - - return art_id.Partitions(value={TIME_PARTITION: self.value}) - + return art_id.TimePartition() -def merge_idl_partitions( - p_idl: Optional[art_id.Partitions], time_p_idl: Optional[art_id.Partitions] -) -> Optional[art_id.Partitions]: - if not p_idl and not time_p_idl: - return None - p = {} - if p_idl and p_idl.value: - p.update(p_idl.value) - if time_p_idl and time_p_idl.value: - p.update(time_p_idl.value) - - return art_id.Partitions(value=p) if p else None + return art_id.TimePartition(value=self.value) def partitions_to_idl( partitions: Optional[Partitions], - time_partition: Optional[TimePartition], bindings: Optional[typing.List[Artifact]] = None, ) -> Optional[art_id.Partitions]: - partition_idl = None if partitions: - partition_idl = partitions.to_flyte_idl(bindings) - - time_p_idl = None - if time_partition: - time_p_idl = time_partition.to_flyte_idl(bindings) + return partitions.to_flyte_idl(bindings) - merged = merge_idl_partitions(partition_idl, time_p_idl) - return merged + return None class Partition(object): @@ -378,7 +367,7 @@ def get_idl_partitions_for_trigger( if not v.reference_artifact or ( v.reference_artifact and v.reference_artifact is self.reference_artifact - and not v.reference_artifact in bindings + and v.reference_artifact not in bindings ): # consider changing condition to just check for static value p[k] = art_id.LabelValue(static_value=v.value.static_value) @@ -543,12 +532,16 @@ def time_partition(self) -> TimePartition: return self._time_partition def __str__(self): + tp_str = f" time partition={self.time_partition}\n" if self.time_partitioned else "" return ( f"Artifact: project={self.project}, domain={self.domain}, name={self.name}, version={self.version}\n" f" name={self.name}\n" f" partitions={self.partitions}\n" + f"{tp_str}" f" tags={self.tags}\n" - f" literal_type={self.literal_type}, literal={self.literal})" + f" literal_type=" + f"{self.literal_type}, " + f"literal={self.literal})" ) def __repr__(self): @@ -668,21 +661,30 @@ def as_artifact_id(self) -> art_id.ArtifactID: return self.to_flyte_idl().artifact_id def embed_as_query( - self, bindings: typing.List[Artifact], partition: Optional[str] = None, expr: Optional[str] = None + self, + bindings: typing.List[Artifact], + partition: Optional[str] = None, + bind_to_time_partition: Optional[bool] = None, + expr: Optional[str] = None, ) -> art_id.ArtifactQuery: """ This should only be called in the context of a Trigger :param bindings: The list of artifacts in trigger_on :param partition: Can embed a time partition + :param bind_to_time_partition: Set to true if you want to bind to a time partition :param expr: Only valid if there's a time partition. """ # Find self in the list, raises ValueError if not there. idx = bindings.index(self) aq = art_id.ArtifactQuery( binding=art_id.ArtifactBindingData( - index=idx, partition_key=partition, transform=str(expr) if expr and partition else None + index=idx, + partition_key=partition, + bind_to_time_partition=bind_to_time_partition, + transform=str(expr) if expr and (partition or bind_to_time_partition) else None, ) ) + return aq def to_flyte_idl(self) -> artifacts_pb2.Artifact: @@ -691,7 +693,8 @@ def to_flyte_idl(self) -> artifacts_pb2.Artifact: This is here instead of translator because it's in the interface, a relatively simple proto object that's exposed to the user. """ - p = partitions_to_idl(self.partitions, self.time_partition if self.time_partitioned else None) + p = partitions_to_idl(self.partitions) + tp = self.time_partition.to_flyte_idl() if self.time_partitioned else None return artifacts_pb2.Artifact( artifact_id=art_id.ArtifactID( @@ -702,6 +705,7 @@ def to_flyte_idl(self) -> artifacts_pb2.Artifact: ), version=self.version, partitions=p, + time_partition=tp, ), spec=artifacts_pb2.ArtifactSpec(), tags=self.tags, @@ -717,9 +721,18 @@ def as_create_request(self) -> artifacts_pb2.CreateArtifactRequest: value=self.literal, type=self.literal_type, ) - partitions = partitions_to_idl(self.partitions, self.time_partition) - tag = self.tags[0] if self.tags else None - return artifacts_pb2.CreateArtifactRequest(artifact_key=ak, spec=spec, partitions=partitions, tag=tag) + partitions = partitions_to_idl(self.partitions) + + tp = None + if self._time_partition: + tv = self.time_partition.value.time_value + if not tv: + raise Exception("missing time value") + tp = self.time_partition.value.time_value + + return artifacts_pb2.CreateArtifactRequest( + artifact_key=ak, spec=spec, partitions=partitions, time_partition_value=tp + ) @classmethod def from_flyte_idl(cls, pb2: artifacts_pb2.Artifact) -> Artifact: @@ -741,17 +754,22 @@ def from_flyte_idl(cls, pb2: artifacts_pb2.Artifact) -> Artifact: if len(pb2.artifact_id.partitions.value) > 0: # static values should be the only ones set since currently we don't from_flyte_idl # anything that's not a materialized artifact. - if TIME_PARTITION in pb2.artifact_id.partitions.value: - a._time_partition = TimePartition(pb2.artifact_id.partitions.value[TIME_PARTITION].static_value) - a._time_partition.reference_artifact = a + # if TIME_PARTITION in pb2.artifact_id.partitions.value: + # a._time_partition = TimePartition(pb2.artifact_id.partitions.value[TIME_PARTITION].static_value) + # a._time_partition.reference_artifact = a a._partitions = Partitions( partitions={ k: Partition(value=v, name=k) for k, v in pb2.artifact_id.partitions.value.items() - if k != TIME_PARTITION + # if k != TIME_PARTITION } ) a.partitions.reference_artifact = a + if pb2.artifact_id.HasField("time_partition"): + ts = pb2.artifact_id.time_partition.value.time_value + dt = ts.ToDatetime() + a._time_partition = TimePartition(dt) + a._time_partition.reference_artifact = a return a diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index f164ab7b25..0746e0f126 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -1,6 +1,6 @@ from datetime import datetime as _datetime from datetime import timezone as _timezone -from typing import Optional, Dict +from typing import Dict, Optional from flyteidl.core import literals_pb2 as _literals_pb2 from google.protobuf.struct_pb2 import Struct @@ -859,7 +859,7 @@ def __init__( collection: Optional[LiteralCollection] = None, map: Optional[LiteralMap] = None, hash: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, + metadata: Optional[Dict[str, str]] = None, ): """ This IDL message represents a literal value in the Flyte ecosystem. diff --git a/flytekit/trigger.py b/flytekit/trigger.py index 6623565d84..62d703f642 100644 --- a/flytekit/trigger.py +++ b/flytekit/trigger.py @@ -6,7 +6,7 @@ from flyteidl.core import identifier_pb2 as idl from flyteidl.core import interface_pb2 -from flytekit.core.artifact import TIME_PARTITION, Artifact, ArtifactQuery, Partition, TimePartition +from flytekit.core.artifact import Artifact, ArtifactQuery, Partition, TimePartition from flytekit.core.context_manager import FlyteContextManager from flytekit.core.launch_plan import LaunchPlan from flytekit.core.tracker import TrackedInstance @@ -77,7 +77,7 @@ def get_parameter_map( expr = None if v.op and v.other and isinstance(v.other, timedelta): expr = str(v.op) + isodate.duration_isoformat(v.other) - aq = v.reference_artifact.embed_as_query(self.triggers, TIME_PARTITION, expr) + aq = v.reference_artifact.embed_as_query(self.triggers, bind_to_time_partition=True, expr=expr) p = interface_pb2.Parameter(var=var, artifact_query=aq) elif isinstance(v, Partition): # The reason is that if we bind to arbitrary partitions, we'll have to start keeping track of types diff --git a/setup.py b/setup.py index fc1f76c84d..606849326a 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,3 @@ from setuptools import setup -setup() \ No newline at end of file +setup() diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py index c83e5cb035..855cba2049 100644 --- a/tests/flytekit/unit/core/test_artifacts.py +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -1,9 +1,9 @@ import datetime -import typing from collections import OrderedDict import pandas as pd import pytest +from flyteidl.core import artifact_id_pb2 as art_id from typing_extensions import Annotated from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -13,7 +13,6 @@ from flytekit.core.task import task from flytekit.core.workflow import workflow from flytekit.tools.translator import get_serializable -from flytekit.types.structured.structured_dataset import StructuredDataset default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = SerializationSettings( @@ -42,10 +41,10 @@ def t1( entities = OrderedDict() t1_s = get_serializable(entities, serialization_settings, t1) - assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 3 + assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2 p = t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value - assert p["ds"].HasField("input_binding") - assert p["ds"].input_binding.var == "dt" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition is not None + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition.value.input_binding.var == "dt" assert p["b"].HasField("input_binding") assert p["b"].input_binding.var == "b_value" assert p["a"].HasField("static_value") @@ -55,6 +54,43 @@ def t1( assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" +def test_basic_option_no_tp(): + a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"]) + assert not a1_t_ab.time_partitioned + + # trying to bind to a time partition when not so raises an error. + with pytest.raises(ValueError): + + @task + def t1x( + b_value: str, dt: datetime.datetime + ) -> Annotated[pd.DataFrame, a1_t_ab(time_partition=Inputs.dt, b=Inputs.b_value, a="manual")]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + @task + def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value, a="manual")]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + entities = OrderedDict() + t1_s = get_serializable(entities, serialization_settings, t1) + assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2 + p = t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.HasField("time_partition") is False + assert p["b"].HasField("input_binding") + + +def test_basic_option_hardcoded_tp(): + a1_t_ab = Artifact(name="my_data", time_partitioned=True) + + dt = datetime.datetime.strptime("04/05/2063", "%m/%d/%Y") + + id_spec = a1_t_ab(time_partition=dt) + assert id_spec.partitions is None + assert id_spec.time_partition.value.HasField("time_value") + + def test_basic_option_a(): a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) @@ -67,10 +103,11 @@ def t1( entities = OrderedDict() t1_s = get_serializable(entities, serialization_settings, t1) - assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 3 + assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2 assert t1_s.template.interface.outputs["o0"].artifact_partial_id.version == "" assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data" assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition is not None def test_basic_option_a2(): @@ -79,7 +116,7 @@ def test_basic_option_a2(): with pytest.raises(ValueError): @task - def t2(b_value: str) -> Annotated[pd.DataFrame, a2_ab(a=Inputs.b_value)]: + def t2x(b_value: str) -> Annotated[pd.DataFrame, a2_ab(a=Inputs.b_value)]: ... @task @@ -107,29 +144,6 @@ def t3(b_value: str) -> Annotated[pd.DataFrame, a3]: assert t3_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data3" -def test_controlling_aliases_when_running(): - task_alias = Artifact(name="task_artifact", tags=["latest"]) - wf_alias = Artifact(name="wf_artifact", tags=["my_v0.1.0"]) - - @task - def t1() -> Annotated[typing.Union[CustomReturn, Annotated[StructuredDataset, "avro"]], task_alias]: - return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]}) - - @workflow - def wf() -> Annotated[CustomReturn, wf_alias]: - u = t1() - return u - - entities = OrderedDict() - spec = get_serializable(entities, serialization_settings, t1) - tag = spec.template.interface.outputs["o0"].artifact_tag - assert tag.value.static_value == "latest" - - spec = get_serializable(entities, serialization_settings, wf) - tag = spec.template.interface.outputs["o0"].artifact_tag - assert tag.value.static_value == "my_v0.1.0" - - def test_query_basic(): aa = Artifact( name="ride_count_data", @@ -142,11 +156,10 @@ def test_query_basic(): dq_idl = data_query.to_flyte_idl() assert dq_idl.HasField("artifact_id") assert dq_idl.artifact_id.artifact_key.name == "ride_count_data" - assert len(dq_idl.artifact_id.partitions.value) == 2 - assert dq_idl.artifact_id.partitions.value["ds"].HasField("input_binding") - assert dq_idl.artifact_id.partitions.value["ds"].input_binding.var == "dt" + assert len(dq_idl.artifact_id.partitions.value) == 1 assert dq_idl.artifact_id.partitions.value["region"].HasField("input_binding") assert dq_idl.artifact_id.partitions.value["region"].input_binding.var == "blah" + assert dq_idl.artifact_id.time_partition.value.input_binding.var == "dt" def test_not_specified_behavior(): @@ -162,11 +175,12 @@ def test_not_specified_behavior(): assert wf_artifact_no_tag.partitions is None aq = wf_artifact_no_tag.query().to_flyte_idl() assert aq.artifact_id.HasField("partitions") is False + assert aq.artifact_id.HasField("time_partition") is False def test_artifact_as_promise_query(): # when artifact is partially specified, can be used as a query input - wf_artifact = Artifact(project="project1", domain="dev", name="wf_artifact", tags=["my_v0.1.0"]) + wf_artifact = Artifact(project="project1", domain="dev", name="wf_artifact") @task def t1(a: CustomReturn) -> CustomReturn: @@ -182,10 +196,9 @@ def wf(a: CustomReturn = wf_artifact.query()): lp = LaunchPlan.get_default_launch_plan(ctx, wf) entities = OrderedDict() spec = get_serializable(entities, serialization_settings, lp) - assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.project == "project1" - assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.domain == "dev" - assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.artifact_key.name == "wf_artifact" - assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_tag.value.static_value == "my_v0.1.0" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.project == "project1" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.domain == "dev" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.name == "wf_artifact" def test_artifact_as_promise(): @@ -206,7 +219,6 @@ def wf2(a: CustomReturn = wf_artifact): lp = LaunchPlan.get_default_launch_plan(ctx, wf2) entities = OrderedDict() spec = get_serializable(entities, serialization_settings, lp) - x = spec.spec.default_inputs.parameters["a"] assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.project == "pro" assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.domain == "dom" assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.name == "key" @@ -214,3 +226,15 @@ def wf2(a: CustomReturn = wf_artifact): aq = wf_artifact.query().to_flyte_idl() assert aq.artifact_id.HasField("partitions") is True assert aq.artifact_id.partitions.value["region"].static_value == "LAX" + + +def test_partition_none(): + # confirm that we can distinguish between partitions being set to empty, and not being set + # though this is not currently used. + ak = art_id.ArtifactKey(project="p", domain="d", name="name") + no_partition = art_id.ArtifactID(artifact_key=ak, version="without_p") + assert not no_partition.HasField("partitions") + + p = art_id.Partitions() + with_partition = art_id.ArtifactID(artifact_key=ak, version="without_p", partitions=p) + assert with_partition.HasField("partitions") diff --git a/tests/flytekit/unit/core/test_triggers.py b/tests/flytekit/unit/core/test_triggers.py index f25920d649..512ecc6ce2 100644 --- a/tests/flytekit/unit/core/test_triggers.py +++ b/tests/flytekit/unit/core/test_triggers.py @@ -10,7 +10,7 @@ def test_basic_11(): - # This test would translate to + # This test translates to # Trigger(trigger_on=[hourlyArtifact], # inputs={"x": hourlyArtifact}) hourlyArtifact = Artifact( @@ -21,6 +21,8 @@ def test_basic_11(): aq_idl = hourlyArtifact.embed_as_query([hourlyArtifact]) assert aq_idl.HasField("binding") assert aq_idl.binding.index == 0 + assert aq_idl.binding.partition_key == "" + assert aq_idl.binding.bind_to_time_partition is False def test_basic_1(): @@ -37,10 +39,12 @@ def test_basic_1(): aq = hourlyArtifact.query(partitions={"region": "LAX"}) aq_idl = aq.to_flyte_idl([hourlyArtifact]) - assert aq_idl.artifact_id.partitions.value["ds"].HasField("triggered_binding") - assert aq_idl.artifact_id.partitions.value["ds"].triggered_binding.index == 0 + assert aq_idl.artifact_id.time_partition.value.HasField("triggered_binding") + assert aq_idl.artifact_id.time_partition.value.triggered_binding.index == 0 + assert aq_idl.artifact_id.time_partition.value.triggered_binding.bind_to_time_partition is True assert aq_idl.artifact_id.partitions.value["some_dim"].HasField("triggered_binding") assert aq_idl.artifact_id.partitions.value["some_dim"].triggered_binding.index == 0 + assert aq_idl.artifact_id.partitions.value["some_dim"].triggered_binding.bind_to_time_partition is False assert aq_idl.artifact_id.partitions.value["region"].static_value == "LAX" @@ -50,9 +54,11 @@ def test_basic_2(): aq = dailyArtifact.query(time_partition=dailyArtifact.time_partition - timedelta(days=1)) aq_idl = aq.to_flyte_idl([dailyArtifact]) x = aq_idl.artifact_id.partitions.value - assert aq_idl.artifact_id.partitions.value["ds"].triggered_binding.index == 0 - assert aq_idl.artifact_id.partitions.value["ds"].triggered_binding.partition_key == "ds" - assert aq_idl.artifact_id.partitions.value["ds"].triggered_binding.transform is not None + assert len(x) == 0 + assert aq_idl.artifact_id.time_partition.value.triggered_binding.index == 0 + assert aq_idl.artifact_id.time_partition.value.triggered_binding.HasField("partition_key") is False + assert aq_idl.artifact_id.time_partition.value.triggered_binding.bind_to_time_partition is True + assert aq_idl.artifact_id.time_partition.value.triggered_binding.transform is not None def test_big_trigger(): @@ -77,7 +83,7 @@ def test_big_trigger(): "other_daily_upstream": hourlyArtifact.query(partitions={"region": "LAX"}), "region": "SEA", # static value that will be passed as input "other_artifact": UnrelatedArtifact.query(time_partition=dailyArtifact.time_partition), - "other_artifact_2": UnrelatedArtifact.query(time_partition=hourlyArtifact.time_partition.truncate_to_day()), + "other_artifact_2": UnrelatedArtifact.query(time_partition=hourlyArtifact.time_partition), "other_artifact_3": UnrelatedTwo.query(partitions={"rgg": hourlyArtifact.partitions.region}), }, ) @@ -102,17 +108,16 @@ def my_workflow( ), ) assert not pm.parameters["today_upstream"].artifact_query.binding.partition_key + assert not pm.parameters["today_upstream"].artifact_query.binding.bind_to_time_partition assert not pm.parameters["today_upstream"].artifact_query.binding.transform assert pm.parameters["yesterday_upstream"].artifact_query == art_id.ArtifactQuery( artifact_id=art_id.ArtifactID( artifact_key=art_id.ArtifactKey(project=None, domain=None, name="daily_artifact"), - partitions=art_id.Partitions( - value={ - "ds": art_id.LabelValue( - triggered_binding=art_id.ArtifactBindingData(index=0, partition_key="ds", transform="-P1D") - ), - } + time_partition=art_id.TimePartition( + value=art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData(index=0, bind_to_time_partition=True, transform="-P1D") + ), ), ), ) @@ -120,9 +125,13 @@ def my_workflow( assert pm.parameters["other_daily_upstream"].artifact_query == art_id.ArtifactQuery( artifact_id=art_id.ArtifactID( artifact_key=art_id.ArtifactKey(project=None, domain=None, name="hourly_artifact"), + time_partition=art_id.TimePartition( + value=art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData(index=1, bind_to_time_partition=True) + ) + ), partitions=art_id.Partitions( value={ - "ds": art_id.LabelValue(triggered_binding=art_id.ArtifactBindingData(index=1, partition_key="ds")), "region": art_id.LabelValue(static_value="LAX"), } ), @@ -136,10 +145,10 @@ def my_workflow( assert pm.parameters["other_artifact"].artifact_query == art_id.ArtifactQuery( artifact_id=art_id.ArtifactID( artifact_key=art_id.ArtifactKey(project=None, domain=None, name="unrelated_artifact"), - partitions=art_id.Partitions( - value={ - "ds": art_id.LabelValue(triggered_binding=art_id.ArtifactBindingData(index=0, partition_key="ds")), - } + time_partition=art_id.TimePartition( + value=art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData(index=0, bind_to_time_partition=True) + ) ), ) ) @@ -147,10 +156,10 @@ def my_workflow( assert pm.parameters["other_artifact_2"].artifact_query == art_id.ArtifactQuery( artifact_id=art_id.ArtifactID( artifact_key=art_id.ArtifactKey(project=None, domain=None, name="unrelated_artifact"), - partitions=art_id.Partitions( - value={ - "ds": art_id.LabelValue(triggered_binding=art_id.ArtifactBindingData(index=1, partition_key="ds")), - } + time_partition=art_id.TimePartition( + value=art_id.LabelValue( + triggered_binding=art_id.ArtifactBindingData(index=1, bind_to_time_partition=True) + ) ), ) ) @@ -169,8 +178,8 @@ def my_workflow( ) idl_t = t.to_flyte_idl() - assert idl_t.triggers[0].partitions.value["ds"] is not None - assert idl_t.triggers[1].partitions.value["ds"] is not None + assert idl_t.triggers[0].HasField("time_partition") + assert idl_t.triggers[1].HasField("time_partition") # Test calling it to create the LaunchPlan object which adds to the global context @t @@ -208,7 +217,7 @@ def tst_wf( assert pm.parameters["today_upstream"].artifact_query == art_id.ArtifactQuery( binding=art_id.ArtifactBindingData( index=0, - partition_key="ds", + bind_to_time_partition=True, transform="-P1D", ) )