diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index 4f1db36389..f2e08042bc 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -183,6 +183,25 @@ def __init__( else: self.binding = None + @property + def bound(self) -> bool: + if self.artifact.time_partitioned and not (self.time_partition and self.time_partition.value): + return False + if self.artifact.partition_keys: + artifact_partitions = set(self.artifact.partition_keys) + query_partitions = set() + if self.partitions and self.partitions.partitions: + pp = self.partitions.partitions + query_partitions = set([k for k in pp.keys() if pp[k].value]) + + if artifact_partitions != query_partitions: + logger.error( + f"Query on {self.artifact.name} missing query params {artifact_partitions - query_partitions}" + ) + return False + + return True + def to_flyte_idl( self, **kwargs, diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index aecca2936d..c139641278 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -225,8 +225,17 @@ def transform_inputs_to_parameters( if isinstance(_default, ArtifactQuery): params[k] = _interface_models.Parameter(var=v, required=False, artifact_query=_default.to_flyte_idl()) elif isinstance(_default, Artifact): - artifact_id = _default.concrete_artifact_id # may raise - params[k] = _interface_models.Parameter(var=v, required=False, artifact_id=artifact_id) + if not _default.version: + # If the artifact is not versioned, assume it's meant to be a query. 
+ q = _default.query() + if q.bound: + params[k] = _interface_models.Parameter(var=v, required=False, artifact_query=q.to_flyte_idl()) + else: + raise FlyteValidationException(f"Cannot use default query with artifact {_default.name}") + else: + # If it is versioned, assume it's intentionally hard-coded + artifact_id = _default.concrete_artifact_id # may raise + params[k] = _interface_models.Parameter(var=v, required=False, artifact_id=artifact_id) else: required = _default is None default_lv = None diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py index 1580832426..c026d1b3ce 100644 --- a/tests/flytekit/unit/core/test_artifacts.py +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -164,6 +164,24 @@ def test_basic_option_hardcoded_tp(): assert id_spec.time_partition.value.HasField("time_value") +def test_bound_ness(): + a1_a = Artifact(name="my_data", partition_keys=["a"]) + q = a1_a.query() + assert not q.bound + + q = a1_a.query(a="hi") + assert q.bound + + +def test_bound_ness_time(): + a1_t = Artifact(name="my_data", time_partitioned=True) + q = a1_t.query() + assert not q.bound + + q = a1_t.query(time_partition=Inputs.dt) + assert q.bound + + def test_basic_option_a(): import pandas as pd @@ -362,7 +380,6 @@ def test_artifact_as_promise_query(): @task def t1(a: CustomReturn) -> CustomReturn: - print(a) return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]}) @workflow @@ -378,6 +395,19 @@ def wf(a: CustomReturn = wf_artifact.query()): assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.domain == "dev" assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.name == "wf_artifact" + # Test non-specified query for unpartitioned artifacts + @workflow + def wf2(a: CustomReturn = wf_artifact): + u = t1(a=a) + return u + + lp2 = LaunchPlan.get_default_launch_plan(ctx, wf2) + entities = OrderedDict() + spec = get_serializable(entities, 
serialization_settings, lp2) + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.project == "project1" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.domain == "dev" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.name == "wf_artifact" + def test_artifact_as_promise(): # when the full artifact is specified, the artifact should be bindable as a literal @@ -389,12 +419,12 @@ def t1(a: CustomReturn) -> CustomReturn: return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]}) @workflow - def wf2(a: CustomReturn = wf_artifact): + def wf3(a: CustomReturn = wf_artifact): u = t1(a=a) return u ctx = FlyteContextManager.current_context() - lp = LaunchPlan.get_default_launch_plan(ctx, wf2) + lp = LaunchPlan.get_default_launch_plan(ctx, wf3) entities = OrderedDict() spec = get_serializable(entities, serialization_settings, lp) assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.project == "pro"