Improve error handling in ShellTask #1732

Merged: 3 commits, Jul 18, 2023
51 changes: 39 additions & 12 deletions flytekit/extras/tasks/shell.py
@@ -11,6 +11,7 @@
 from flytekit.core.interface import Interface
 from flytekit.core.python_function_task import PythonInstanceTask
 from flytekit.core.task import TaskPlugins
+from flytekit.exceptions.user import FlyteRecoverableException
 from flytekit.loggers import logger
 from flytekit.types.directory import FlyteDirectory
 from flytekit.types.file import FlyteFile
@@ -98,6 +99,30 @@ def interpolate(
 T = typing.TypeVar("T")


+def _run_script(script) -> typing.Tuple[int, str, str]:
+    """
+    Run script as a subprocess and return the returncode, stdout, and stderr.
+
+    While executing the subprocess, stdout of the subprocess will be printed
+    to the current process stdout so that the subprocess execution will not appear unresponsive
+
+    :param script: script to be executed
+    :type script: str
+    :return: tuple containing the process returncode, stdout, and stderr
+    :rtype: typing.Tuple[int, str, str]
+    """
+    process = subprocess.Popen(script, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0, shell=True, text=True)
+
+    # print stdout so that long-running subprocess will not appear unresponsive
Member

qq: Does the pod's log currently not show any stdout or stderr?

Member Author

Currently it does. However, this PR captures stdout, so we need to emit it manually.
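The point in this exchange (once stdout is piped for capture, it no longer reaches the pod log unless it is echoed) can be illustrated with a small standalone sketch. This is not the PR's code: the name `run_and_echo` is hypothetical, and draining stderr on a helper thread is an addition assumed here so that a child writing a lot to stderr cannot stall while stdout is being read.

```python
import subprocess
import sys
import threading
import typing


def run_and_echo(script: str) -> typing.Tuple[int, str, str]:
    """Hypothetical helper: run a shell script, echo its stdout live, and capture both streams."""
    process = subprocess.Popen(
        script,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True,
        text=True,
    )

    # Drain stderr on a separate thread so its pipe buffer never fills up
    # while the main thread is busy reading stdout.
    err_chunks: typing.List[str] = []
    drainer = threading.Thread(target=lambda: err_chunks.append(process.stderr.read()))
    drainer.start()

    out_lines: typing.List[str] = []
    for line in process.stdout:
        # Each line already ends with a newline, so write it unchanged
        # instead of using print(), which would append a second one.
        sys.stdout.write(line)
        sys.stdout.flush()
        out_lines.append(line)

    code = process.wait()
    drainer.join()
    return code, "".join(out_lines), "".join(err_chunks)
```

Calling `run_and_echo("echo hello")` echoes the line to the current process stdout and also returns it in the captured output, which is what makes it available for an error message later.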

+    out = ""
+    for line in process.stdout:
+        print(line)
+        out += line
+
+    code = process.wait()
+    return code, out, process.stderr.read()
+
+
 class ShellTask(PythonInstanceTask[T]):
     """ """

@@ -213,21 +238,23 @@ def execute(self, **kwargs) -> typing.Any:
             print(gen_script)
             print("\n==============================================\n")

-        try:
-            if platform.system() == "Windows" and os.environ.get("ComSpec") is None:
-                # https://github.com/python/cpython/issues/101283
-                os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe"
-            subprocess.check_call(gen_script, shell=True)
-        except subprocess.CalledProcessError as e:
+        if platform.system() == "Windows" and os.environ.get("ComSpec") is None:
+            # https://github.com/python/cpython/issues/101283
+            os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe"
+
+        returncode, stdout, stderr = _run_script(gen_script)
+        if returncode != 0:
             files = os.listdir(".")
             fstr = "\n-".join(files)
-            logger.error(
-                f"Failed to Execute Script, return-code {e.returncode} \n"
-                f"StdErr: {e.stderr}\n"
-                f"StdOut: {e.stdout}\n"
-                f" Current directory contents: .\n-{fstr}"
+            error = (
+                f"Failed to Execute Script, return-code {returncode} \n"
+                f"Current directory contents: .\n-{fstr}\n"
+                f"StdOut: {stdout}\n"
+                f"StdErr: {stderr}\n"
             )
-            raise
+            logger.error(error)
+            # raise FlyteRecoverableException so that it's classified as user error and will be retried
+            raise FlyteRecoverableException(error)
Member

I think not every user error should be a recoverable one 🤔

Take a look at how PythonFunctionTask separates user vs system error scopes:

return exception_scopes.user_entry_point(self._task_function)(**kwargs)
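To make the idea of error scopes concrete, here is a rough standalone sketch of how a wrapper can tag exceptions by where they originate. It does not reproduce flytekit's `exception_scopes` module: the names `UserScopedError`, `SystemScopedError`, `user_scope`, and `system_scope` are all hypothetical and exist only for illustration.

```python
import functools


class UserScopedError(Exception):
    """Hypothetical marker: the failure came from user-provided code."""


class SystemScopedError(Exception):
    """Hypothetical marker: the failure came from framework code."""


def user_scope(fn):
    """Wrap fn so that anything it raises is tagged as a user error."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            raise UserScopedError(str(exc)) from exc

    return wrapper


def system_scope(fn):
    """Wrap fn so that untagged exceptions are tagged as system errors."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except UserScopedError:
            # Already classified by an inner user scope; pass it through.
            raise
        except Exception as exc:
            raise SystemScopedError(str(exc)) from exc

    return wrapper
```

With this layering, the outer framework call is wrapped in the system scope and only the user's function is wrapped in the user scope, so the classification follows where the exception was raised rather than what type it has.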

Member Author

> I think not every user error should be a recoverable one

I agree. However, since it's a shell execution, I am not sure how we can differentiate recoverable from non-recoverable user errors. I am defaulting to a recoverable error so that users of ShellTask can still leverage the task retry feature.

Member

Right, of course, makes sense 👍
This change is not really backwards compatible, though. I wonder whether this is a dealbreaker 🤔

Member Author

Yes, there is undoubtedly a behavior change when a shell task raises an error.

Before: An error in a shell task is considered a system error and is therefore retried N times, where N is 3 by default.

After: An error in a shell task is considered a user recoverable error. By default, it is not retried; it is retried only if users specify retries in the task's metadata.

However, the shell task previously didn't respect the retry strategy set by users, and I think that can be considered a bug.
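The "after" behavior described here can be modeled with a small standalone sketch. This is not Flyte's scheduler: `RecoverableError` and `run_with_retries` are hypothetical names, and the retry budget defaulting to 0 is taken from the comment above rather than from the platform's code.

```python
import typing


class RecoverableError(Exception):
    """Hypothetical stand-in for a user recoverable error."""


def run_with_retries(task: typing.Callable[[], str], retries: int = 0) -> typing.Tuple[str, int]:
    """Run task, retrying recoverable failures up to `retries` extra times.

    Returns the task result together with the number of attempts made.
    """
    attempts = 0
    while True:
        attempts += 1
        try:
            return task(), attempts
        except RecoverableError:
            # With the default budget of 0, the first failure is final.
            if attempts > retries:
                raise


def make_flaky_task(failures: int) -> typing.Callable[[], str]:
    """Build a task that fails `failures` times before succeeding."""
    state = {"calls": 0}

    def task() -> str:
        state["calls"] += 1
        if state["calls"] <= failures:
            raise RecoverableError("transient shell failure")
        return "ok"

    return task
```

In this model, a task that fails twice succeeds on its third attempt when `retries=3`, while the same task with the default `retries=0` fails on its first attempt. That mirrors the point being made: the retry count comes from what the user configured, not from a fixed system default.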

Member Author

@wild-endeavor your opinion will be very helpful here.

Member

+1, I think the shell task should throw a user recoverable error


         final_outputs = []
         for v in self._output_locs:
4 changes: 2 additions & 2 deletions tests/flytekit/unit/extras/tasks/test_shell.py
@@ -3,13 +3,13 @@
 import tempfile
 import typing
 from dataclasses import dataclass
-from subprocess import CalledProcessError

 import pytest
 from dataclasses_json import dataclass_json

 import flytekit
 from flytekit import kwtypes
+from flytekit.exceptions.user import FlyteRecoverableException
 from flytekit.extras.tasks.shell import OutputLocation, RawShellTask, ShellTask, get_raw_shell_task
 from flytekit.types.directory import FlyteDirectory
 from flytekit.types.file import CSVFile, FlyteFile
@@ -64,7 +64,7 @@ def test_input_substitution_primitive():

     t(f=os.path.join(test_file_path, "__init__.py"), y=5, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
     t(f=os.path.join(test_file_path, "test_shell.py"), y=5, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
-    with pytest.raises(CalledProcessError):
+    with pytest.raises(FlyteRecoverableException):
         t(f="non_exist.py", y=5, j=datetime.datetime(2021, 11, 10, 12, 15, 0))