Skip to content

Commit

Permalink
Raise an exception in case of local execution of raw containers tasks (
Browse files Browse the repository at this point in the history
…#1745)

* Raise an exception in case of local raw containers

Signed-off-by: eduardo apolinario <[email protected]>

* Remove ContainerTask tests from test_type_hints

Signed-off-by: eduardo apolinario <[email protected]>

---------

Signed-off-by: eduardo apolinario <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: eduardo apolinario <[email protected]>
  • Loading branch information
2 people authored and Fabio Grätz committed Aug 14, 2023
1 parent 9dbc855 commit 4b1e4a5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 73 deletions.
13 changes: 3 additions & 10 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.context_manager import FlyteContext
from flytekit.core.interface import Interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
Expand Down Expand Up @@ -89,16 +90,8 @@ def __init__(
def resources(self) -> ResourceSpec:
return self._resources

def execute(self, **kwargs) -> Any:
print(kwargs)
env = ""
for k, v in self.environment.items():
env += f" -e {k}={v}"
print(
f"\ndocker run --rm -v /tmp/inputs:{self._input_data_dir} -v /tmp/outputs:{self._output_data_dir} {env}"
f"{self._image} {self._cmd} {self._args}"
)
return None
def local_execute(self, ctx: FlyteContext, **kwargs) -> Any:
raise RuntimeError("ContainerTask is not supported in local executions.")

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container
Expand Down
14 changes: 14 additions & 0 deletions tests/flytekit/unit/core/test_container_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from kubernetes.client.models import (
V1Affinity,
V1NodeAffinity,
Expand Down Expand Up @@ -78,3 +79,16 @@ def test_pod_template():
{"effect": "NoSchedule", "key": "nvidia.com/gpu", "operator": "Exists"}
]
assert serialized_pod_spec["runtimeClassName"] == "nvidia"


def test_local_execution():
ct = ContainerTask(
name="name",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
image="inexistent-image:v42",
command=["some", "command"],
)

with pytest.raises(RuntimeError):
ct()
64 changes: 1 addition & 63 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import flytekit
import flytekit.configuration
from flytekit import ContainerTask, Secret, SQLTask, dynamic, kwtypes, map_task
from flytekit import Secret, SQLTask, dynamic, kwtypes, map_task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig
from flytekit.core import context_manager, launch_plan, promise
from flytekit.core.condition import conditional
Expand Down Expand Up @@ -925,68 +925,6 @@ def my_subwf(a: int) -> typing.Tuple[str, str]:
assert parameter_a.default is not None


def test_wf_container_task():
@task
def t1(a: int) -> (int, str):
return a + 2, str(a) + "-HELLO"

t2 = ContainerTask(
"raw",
image="alpine",
inputs=kwtypes(a=int, b=str),
input_data_dir="/tmp",
output_data_dir="/tmp",
command=["cat"],
arguments=["/tmp/a"],
)

@workflow
def wf(a: int):
x, y = t1(a=a)
t2(a=x, b=y)

with task_mock(t2) as mock:
mock.side_effect = lambda a, b: None
assert t2(a=10, b="hello") is None

wf(a=10)


def test_wf_container_task_multiple():
square = ContainerTask(
name="square",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs=kwtypes(val=int),
outputs=kwtypes(out=int),
image="alpine",
command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
)

sum = ContainerTask(
name="sum",
input_data_dir="/var/flyte/inputs",
output_data_dir="/var/flyte/outputs",
inputs=kwtypes(x=int, y=int),
outputs=kwtypes(out=int),
image="alpine",
command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
)

@workflow
def raw_container_wf(val1: int, val2: int) -> int:
return sum(x=square(val=val1), y=square(val=val2))

with task_mock(square) as square_mock, task_mock(sum) as sum_mock:
square_mock.side_effect = lambda val: val * val
assert square(val=10) == 100

sum_mock.side_effect = lambda x, y: x + y
assert sum(x=10, y=10) == 20

assert raw_container_wf(val1=10, val2=10) == 200


def test_wf_tuple_fails():
with pytest.raises(RestrictedTypeError):

Expand Down

0 comments on commit 4b1e4a5

Please sign in to comment.