Skip to content

Commit

Permalink
Query by default when missing (#2379)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Apr 25, 2024
1 parent e3f92cf commit 93f690c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 5 deletions.
19 changes: 19 additions & 0 deletions flytekit/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,25 @@ def __init__(
else:
self.binding = None

@property
def bound(self) -> bool:
if self.artifact.time_partitioned and not (self.time_partition and self.time_partition.value):
return False
if self.artifact.partition_keys:
artifact_partitions = set(self.artifact.partition_keys)
query_partitions = set()
if self.partitions and self.partitions.partitions:
pp = self.partitions.partitions
query_partitions = set([k for k in pp.keys() if pp[k].value])

if artifact_partitions != query_partitions:
logger.error(
f"Query on {self.artifact.name} missing query params {artifact_partitions - query_partitions}"
)
return False

return True

def to_flyte_idl(
self,
**kwargs,
Expand Down
13 changes: 11 additions & 2 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,17 @@ def transform_inputs_to_parameters(
if isinstance(_default, ArtifactQuery):
params[k] = _interface_models.Parameter(var=v, required=False, artifact_query=_default.to_flyte_idl())
elif isinstance(_default, Artifact):
artifact_id = _default.concrete_artifact_id # may raise
params[k] = _interface_models.Parameter(var=v, required=False, artifact_id=artifact_id)
if not _default.version:
# If the artifact is not versioned, assume it's meant to be a query.
q = _default.query()
if q.bound:
params[k] = _interface_models.Parameter(var=v, required=False, artifact_query=q.to_flyte_idl())
else:
raise FlyteValidationException(f"Cannot use default query with artifact {_default.name}")
else:
# If it is versioned, assumed it's intentionally hard-coded
artifact_id = _default.concrete_artifact_id # may raise
params[k] = _interface_models.Parameter(var=v, required=False, artifact_id=artifact_id)
else:
required = _default is None
default_lv = None
Expand Down
36 changes: 33 additions & 3 deletions tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ def test_basic_option_hardcoded_tp():
assert id_spec.time_partition.value.HasField("time_value")


def test_bound_ness():
a1_a = Artifact(name="my_data", partition_keys=["a"])
q = a1_a.query()
assert not q.bound

q = a1_a.query(a="hi")
assert q.bound


def test_bound_ness_time():
a1_t = Artifact(name="my_data", time_partitioned=True)
q = a1_t.query()
assert not q.bound

q = a1_t.query(time_partition=Inputs.dt)
assert q.bound


def test_basic_option_a():
import pandas as pd

Expand Down Expand Up @@ -362,7 +380,6 @@ def test_artifact_as_promise_query():

@task
def t1(a: CustomReturn) -> CustomReturn:
print(a)
return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]})

@workflow
Expand All @@ -378,6 +395,19 @@ def wf(a: CustomReturn = wf_artifact.query()):
assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.domain == "dev"
assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.name == "wf_artifact"

# Test non-specified query for unpartitioned artifacts
@workflow
def wf2(a: CustomReturn = wf_artifact):
u = t1(a=a)
return u

lp2 = LaunchPlan.get_default_launch_plan(ctx, wf2)
entities = OrderedDict()
spec = get_serializable(entities, serialization_settings, lp2)
assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.project == "project1"
assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.domain == "dev"
assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.name == "wf_artifact"


def test_artifact_as_promise():
# when the full artifact is specified, the artifact should be bindable as a literal
Expand All @@ -389,12 +419,12 @@ def t1(a: CustomReturn) -> CustomReturn:
return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]})

@workflow
def wf2(a: CustomReturn = wf_artifact):
def wf3(a: CustomReturn = wf_artifact):
u = t1(a=a)
return u

ctx = FlyteContextManager.current_context()
lp = LaunchPlan.get_default_launch_plan(ctx, wf2)
lp = LaunchPlan.get_default_launch_plan(ctx, wf3)
entities = OrderedDict()
spec = get_serializable(entities, serialization_settings, lp)
assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.project == "pro"
Expand Down

0 comments on commit 93f690c

Please sign in to comment.