diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 4c9150881d..4fe8e669ab 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -679,6 +679,11 @@ def __rshift__(self, other: typing.Union[Promise, VoidPromise]): if self.ref: self.ref.node.runs_before(other.ref.node) + def with_overrides(self, *args, **kwargs): + if self.ref: + self.ref.node.with_overrides(*args, **kwargs) + return self + @property def task_name(self): return self._task_name diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index a303230386..f6dc9c9ba5 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -372,3 +372,27 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.interruptible == interruptible + + +def test_void_promise_override(): + @task + def t1(a: str): + print(f"*~*~*~{a}*~*~*~") + + @workflow + def my_wf(a: str): + t1(a=a).with_overrides(requests=Resources(cpu="1", mem="100")) + + serialization_settings = flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == [ + _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"), + _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"), + ]