From 79d7c37f889a4202a613e45ba4fb6e7c9d021ab7 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 4 Jan 2023 17:29:52 -0800 Subject: [PATCH 1/5] wip Signed-off-by: Kevin Su --- flytekit/core/node.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index d8b43f2728..4654802a01 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -114,6 +114,11 @@ def with_overrides(self, *args, **kwargs): self._metadata._interruptible = kwargs["interruptible"] if "name" in kwargs: self._metadata._name = kwargs["name"] + if "task_config" in kwargs: + new_task_config = kwargs["task_config"] + # if type(new_task_config) != type(self._task_config): + self._task_config = new_task_config + print("new task type", new_task_config) return self From fd74db951bf7489be2f539e318657796c6c32bcd Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Jan 2023 12:00:34 -0800 Subject: [PATCH 2/5] Override task config Signed-off-by: Kevin Su --- flytekit/core/node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 4654802a01..76a44a9780 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -116,9 +116,9 @@ def with_overrides(self, *args, **kwargs): self._metadata._name = kwargs["name"] if "task_config" in kwargs: new_task_config = kwargs["task_config"] - # if type(new_task_config) != type(self._task_config): - self._task_config = new_task_config - print("new task type", new_task_config) + if not isinstance(new_task_config, type(self.flyte_entity._task_config)): + raise ValueError("can't change the type of the task config") + self.flyte_entity._task_config = new_task_config return self From 03666e32d79e1e814ec79a9aa1694bcd733aec52 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Jan 2023 13:25:09 -0800 Subject: [PATCH 3/5] Add tests Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_node_creation.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 47c8af9830..4caa0e9075 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -1,6 +1,7 @@ import datetime import typing from collections import OrderedDict +from dataclasses import dataclass import pytest @@ -424,3 +425,19 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.name == "foo" + + +def test_config_override(): + @dataclass + class DummyConfig: + name: str + + @task(task_config=DummyConfig(name="hello")) + def t1(a: str) -> str: + return f"*~*~*~{a}*~*~*~" + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(task_config=DummyConfig("flyte")) + + assert my_wf.nodes[0].flyte_entity.task_config.name == "flyte" From 2f5e5147a15098abc4b91b7eda67832ffc8b5070 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 11 Jan 2023 13:30:15 -0800 Subject: [PATCH 4/5] Update tests Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_node_creation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 4caa0e9075..2813563fb9 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -441,3 +441,9 @@ def my_wf(a: str) -> str: return t1(a=a).with_overrides(task_config=DummyConfig("flyte")) assert my_wf.nodes[0].flyte_entity.task_config.name == "flyte" + + with pytest.raises(ValueError): + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(task_config=None) From a92aa24b1cf8805d41f9a48e927398fb0cb5e890 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 12 Jan 2023 15:26:33 -0800 Subject: [PATCH 5/5] add logger Signed-off-by: Kevin Su --- flytekit/core/node.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 76a44a9780..617790746f 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -6,6 +6,7 @@ from flytekit.core.resources import Resources from flytekit.core.utils import _dnsify +from flytekit.loggers import logger from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.task import Resources as _resources_model @@ -115,6 +116,7 @@ def with_overrides(self, *args, **kwargs): if "name" in kwargs: self._metadata._name = kwargs["name"] if "task_config" in kwargs: + logger.warning("This override is beta. We may want to revisit this in the future.") new_task_config = kwargs["task_config"] if not isinstance(new_task_config, type(self.flyte_entity._task_config)): raise ValueError("can't change the type of the task config")