Skip to content

Commit

Permalink
Allow overrides of timeout, retries and interruptible in nodes (#593)
Browse files Browse the repository at this point in the history
Signed-off-by: Jeev B <[email protected]>
  • Loading branch information
jeevb authored Aug 12, 2021
1 parent 3117fa4 commit 9ac450c
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
18 changes: 18 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import typing
from typing import Any, List

Expand Down Expand Up @@ -89,6 +90,23 @@ def with_overrides(self, *args, **kwargs):
requests = _convert_resource_overrides(kwargs.get("requests"), "requests")
limits = _convert_resource_overrides(kwargs.get("limits"), "limits")
self._resources = _resources_model(requests=requests, limits=limits)
if "timeout" in kwargs:
timeout = kwargs["timeout"]
if timeout is None:
self._metadata._timeout = datetime.timedelta()
elif isinstance(timeout, int):
self._metadata._timeout = datetime.timedelta(seconds=timeout)
elif isinstance(timeout, datetime.timedelta):
self._metadata._timeout = timeout
else:
raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
if "retries" in kwargs:
retries = kwargs["retries"]
self._metadata._retries = (
_literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries)
)
if "interruptible" in kwargs:
self._metadata._interruptible = kwargs["interruptible"]
return self


Expand Down
85 changes: 85 additions & 0 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import typing
from collections import OrderedDict

Expand All @@ -12,6 +13,7 @@
from flytekit.core.node_creation import create_node
from flytekit.core.task import task
from flytekit.core.workflow import workflow
from flytekit.models import literals as _literal_models
from flytekit.models.task import Resources as _resources_models


Expand Down Expand Up @@ -258,3 +260,86 @@ def my_wf(a: typing.List[str]) -> typing.List[str]:
_resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "200"),
_resources_models.ResourceEntry(_resources_models.ResourceName.EPHEMERAL_STORAGE, "1Gi"),
]


@pytest.mark.parametrize(
"timeout,expected",
[(None, datetime.timedelta()), (10, datetime.timedelta(seconds=10))],
)
def test_timeout_override(timeout, expected):
@task
def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(timeout=timeout)

serialization_settings = context_manager.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].metadata.timeout == expected


def test_timeout_override_invalid_value():
@task
def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

with pytest.raises(ValueError, match="datetime.timedelta or int seconds"):

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(timeout="foo")


@pytest.mark.parametrize(
"retries,expected", [(None, _literal_models.RetryStrategy(0)), (3, _literal_models.RetryStrategy(3))]
)
def test_retries_override(retries, expected):
@task
def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(retries=retries)

serialization_settings = context_manager.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].metadata.retries == expected


@pytest.mark.parametrize("interruptible", [None, True, False])
def test_interruptible_override(interruptible):
@task
def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(interruptible=interruptible)

serialization_settings = context_manager.SerializationSettings(
project="test_proj",
domain="test_domain",
version="abc",
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].metadata.interruptible == interruptible

0 comments on commit 9ac450c

Please sign in to comment.