Skip to content

Commit

Permalink
Improve error handling in ShellTask (#1732)
Browse files Browse the repository at this point in the history
* Improve error handling in ShellTask

Signed-off-by: Pradithya Aria <[email protected]>

* Add new line

Signed-off-by: Pradithya Aria <[email protected]>

* Capture stdout

Signed-off-by: Pradithya Aria <[email protected]>

---------

Signed-off-by: Pradithya Aria <[email protected]>
Co-authored-by: Pradithya Aria <[email protected]>
  • Loading branch information
pradithya and Pradithya Aria authored Jul 18, 2023
1 parent fac7085 commit 604cabd
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
51 changes: 39 additions & 12 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from flytekit.core.interface import Interface
from flytekit.core.python_function_task import PythonInstanceTask
from flytekit.core.task import TaskPlugins
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.loggers import logger
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
Expand Down Expand Up @@ -98,6 +99,30 @@ def interpolate(
T = typing.TypeVar("T")


def _run_script(script) -> typing.Tuple[int, str, str]:
"""
Run script as a subprocess and return the returncode, stdout, and stderr.
While executing the su process, stdout of the subprocess will be printed
to the current process stdout so that the subprocess execution will not appear unresponsive
:param script: script to be executed
:type script: str
:return: tuple containing the process returncode, stdout, and stderr
:rtype: typing.Tuple[int, str, str]
"""
process = subprocess.Popen(script, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0, shell=True, text=True)

# print stdout so that long-running subprocess will not appear unresponsive
out = ""
for line in process.stdout:
print(line)
out += line

code = process.wait()
return code, out, process.stderr.read()


class ShellTask(PythonInstanceTask[T]):
""" """

Expand Down Expand Up @@ -213,21 +238,23 @@ def execute(self, **kwargs) -> typing.Any:
print(gen_script)
print("\n==============================================\n")

try:
if platform.system() == "Windows" and os.environ.get("ComSpec") is None:
# https://github.com/python/cpython/issues/101283
os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe"
subprocess.check_call(gen_script, shell=True)
except subprocess.CalledProcessError as e:
if platform.system() == "Windows" and os.environ.get("ComSpec") is None:
# https://github.com/python/cpython/issues/101283
os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe"

returncode, stdout, stderr = _run_script(gen_script)
if returncode != 0:
files = os.listdir(".")
fstr = "\n-".join(files)
logger.error(
f"Failed to Execute Script, return-code {e.returncode} \n"
f"StdErr: {e.stderr}\n"
f"StdOut: {e.stdout}\n"
f" Current directory contents: .\n-{fstr}"
error = (
f"Failed to Execute Script, return-code {returncode} \n"
f"Current directory contents: .\n-{fstr}\n"
f"StdOut: {stdout}\n"
f"StdErr: {stderr}\n"
)
raise
logger.error(error)
# raise FlyteRecoverableException so that it's classified as user error and will be retried
raise FlyteRecoverableException(error)

final_outputs = []
for v in self._output_locs:
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/extras/tasks/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import tempfile
import typing
from dataclasses import dataclass
from subprocess import CalledProcessError

import pytest
from dataclasses_json import dataclass_json

import flytekit
from flytekit import kwtypes
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.extras.tasks.shell import OutputLocation, RawShellTask, ShellTask, get_raw_shell_task
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import CSVFile, FlyteFile
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_input_substitution_primitive():

t(f=os.path.join(test_file_path, "__init__.py"), y=5, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
t(f=os.path.join(test_file_path, "test_shell.py"), y=5, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
with pytest.raises(CalledProcessError):
with pytest.raises(FlyteRecoverableException):
t(f="non_exist.py", y=5, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


Expand Down

0 comments on commit 604cabd

Please sign in to comment.