Skip to content

Commit

Permalink
Add support for overriding task configurations (#1410)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
pingsutw authored and samhita-alla committed Feb 2, 2023
1 parent eaa3115 commit 3436c8e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
7 changes: 7 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.task import Resources as _resources_model
Expand Down Expand Up @@ -119,6 +120,12 @@ def with_overrides(self, *args, **kwargs):
self._metadata._interruptible = kwargs["interruptible"]
if "name" in kwargs:
self._metadata._name = kwargs["name"]
if "task_config" in kwargs:
logger.warning("This override is beta. We may want to revisit this in the future.")
new_task_config = kwargs["task_config"]
if not isinstance(new_task_config, type(self.flyte_entity._task_config)):
raise ValueError("can't change the type of the task config")
self.flyte_entity._task_config = new_task_config
return self


Expand Down
23 changes: 23 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import typing
from collections import OrderedDict
from dataclasses import dataclass

import pytest

Expand Down Expand Up @@ -424,3 +425,25 @@ def my_wf(a: str) -> str:
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].metadata.name == "foo"


def test_config_override():
@dataclass
class DummyConfig:
name: str

@task(task_config=DummyConfig(name="hello"))
def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=DummyConfig("flyte"))

assert my_wf.nodes[0].flyte_entity.task_config.name == "flyte"

with pytest.raises(ValueError):

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=None)

0 comments on commit 3436c8e

Please sign in to comment.