Skip to content

Commit

Permalink
Add support for string-format-like syntax for shell task (flyteorg#792)
Browse files Browse the repository at this point in the history
* POC: Add support for f-string like syntax for shell task

This commit is a proof of concept adding f-string like syntax for
shell_tasks. This supports using nested types for script inputs, such
as data classes. This change was motivated by the desire to combine
shell_tasks that have multiple inputs with map_tasks which only support
tasks with a single input.

This commit is only a starting point, since it makes some changes to the
shell_task API (adds a template_style field), and modifies some of the default
behavior for ease of implementation (e.g. throwing an error when there
are unused input arguments).

Signed-off-by: Zach Palchick <[email protected]>

* Drop support for old/regex style for doing string interpolation

Signed-off-by: Zach Palchick <[email protected]>
  • Loading branch information
palchicz authored and kennyworkman committed Feb 8, 2022
1 parent 11b3e9b commit f95993a
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 87 deletions.
96 changes: 47 additions & 49 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import collections
import datetime
import logging
import os
import re
import string
import subprocess
import typing
from dataclasses import dataclass
Expand Down Expand Up @@ -30,46 +31,6 @@ class OutputLocation:
location: typing.Union[os.PathLike, str]


def _stringify(v: typing.Any) -> str:
    """
    Convert a task input value into the string substituted into a script template.

    FlyteFile and FlyteDirectory values have ``download()`` called on them and are
    rendered as their ``path``. ``datetime.datetime`` values are rendered with
    ``isoformat()``. Anything else goes through ``str()``.
    """
    # Files and directories are handled identically: download, then use the path.
    if isinstance(v, (FlyteFile, FlyteDirectory)):
        v.download()
        return v.path
    return v.isoformat() if isinstance(v, datetime.datetime) else str(v)


def _interpolate(tmpl: str, regex: re.Pattern, validate_all_match: bool = True, **kwargs) -> str:
    """
    Return a copy of ``tmpl`` with every placeholder matched by ``regex`` replaced
    by the stringified value of the corresponding keyword argument.

    ``regex`` must expose two groups: the full placeholder expression and the
    variable name inside it. The template passed in is never modified.

    Raises:
        ValueError: if a placeholder names a variable that is not in ``kwargs``, or,
            when ``validate_all_match`` is true, if some keyword argument is never
            referenced by the template.
    """
    rendered = tmpl
    seen = set()
    for found in regex.finditer(tmpl):
        captured = found.groups()
        placeholder = captured[0]
        name = captured[1]
        if name not in kwargs:
            raise ValueError(f"Variable {name} in Query (part of {placeholder}) not found in inputs {kwargs.keys()}")
        seen.add(name)
        # str conversion is deliberate: each type gets its own conversion in _stringify
        rendered = rendered.replace(placeholder, _stringify(kwargs[name]))

    # Nothing left to check when validation is off or every input was referenced.
    if not validate_all_match or len(seen) >= len(kwargs.keys()):
        return rendered

    diff = set(kwargs.keys()).difference(seen)
    raise ValueError(f"Extra Inputs have no matches in script template - missing {diff}")


def _dummy_task_func():
"""
A Fake function to satisfy the inner PythonTask requirements
Expand All @@ -80,12 +41,51 @@ def _dummy_task_func():
T = typing.TypeVar("T")


class _PythonFStringInterpolizer:
    """A class for interpolating scripts that use python string.format syntax"""

    class _Formatter(string.Formatter):
        """``string.Formatter`` with special rendering for Flyte files/directories and datetimes."""

        def format_field(self, value, format_spec):
            """
            Return the text substituted for a single replacement field.

            FlyteFile and FlyteDirectory values have ``download()`` called on them and are
            rendered as their ``path``. ``datetime.datetime`` values are rendered with
            ``isoformat()``. For those types ``format_spec`` is not consulted. Every other
            value is delegated to ``string.Formatter.format_field``.
            """
            if isinstance(value, (FlyteFile, FlyteDirectory)):
                value.download()
                return value.path
            if isinstance(value, datetime.datetime):
                return value.isoformat()
            return super().format_field(value, format_spec)

    def interpolate(
        self,
        tmpl: str,
        inputs: typing.Optional[typing.Dict[str, typing.Any]] = None,
        outputs: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> str:
        """
        Interpolate python formatted string templates with variables from the input and output
        argument dicts. The result is non destructive towards the given template string.

        :param tmpl: template using ``str.format`` replacement fields, e.g. ``"cat {f} > {y}"``
        :param inputs: mapping of input variable name to value; ``None`` is treated as empty
        :param outputs: mapping of output variable name to value; ``None`` is treated as empty
        :return: the template with every replacement field substituted
        :raises ValueError: if a name appears in both ``inputs`` and ``outputs``, or if the
            template references a variable present in neither
        """
        inputs = inputs or {}
        outputs = outputs or {}
        reused_vars = inputs.keys() & outputs.keys()
        if reused_vars:
            raise ValueError(f"Variables {reused_vars} in Query cannot be shared between inputs and outputs.")
        consolidated_args = collections.ChainMap(inputs, outputs)
        try:
            return self._Formatter().format(tmpl, **consolidated_args)
        except KeyError as e:
            # Report only the variable names. Formatting ChainMap.keys() would embed the
            # repr of the whole ChainMap, i.e. every input and output value.
            available = sorted(consolidated_args)
            raise ValueError(f"Variable {e} in Query not found in inputs {available}") from e


class ShellTask(PythonInstanceTask[T]):
""" """

_INPUT_REGEX = re.compile(r"({{\s*.inputs.(\w+)\s*}})", re.IGNORECASE)
_OUTPUT_REGEX = re.compile(r"({{\s*.outputs.(\w+)\s*}})", re.IGNORECASE)

def __init__(
self,
name: str,
Expand Down Expand Up @@ -136,6 +136,7 @@ def __init__(
self._script_file = script_file
self._debug = debug
self._output_locs = output_locs if output_locs else []
self._interpolizer = _PythonFStringInterpolizer()
outputs = self._validate_output_locs()
super().__init__(
name,
Expand Down Expand Up @@ -184,12 +185,9 @@ def execute(self, **kwargs) -> typing.Any:
outputs: typing.Dict[str, str] = {}
if self._output_locs:
for v in self._output_locs:
outputs[v.var] = _interpolate(v.location, self._INPUT_REGEX, validate_all_match=False, **kwargs)
outputs[v.var] = self._interpolizer.interpolate(v.location, inputs=kwargs)

gen_script = _interpolate(self._script, self._INPUT_REGEX, **kwargs)
# For outputs it is not necessary that all outputs are used in the script, some are implicit outputs
# for example gcc main.c will generate a.out automatically
gen_script = _interpolate(gen_script, self._OUTPUT_REGEX, validate_all_match=False, **outputs)
gen_script = self._interpolizer.interpolate(self._script, inputs=kwargs, outputs=outputs)
if self._debug:
print("\n==============================================\n")
print(gen_script)
Expand Down
106 changes: 70 additions & 36 deletions tests/flytekit/unit/extras/tasks/test_shell.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import datetime
import os
import tempfile
from dataclasses import dataclass
from subprocess import CalledProcessError

import pytest
from dataclasses_json import dataclass_json

from flytekit import kwtypes
from flytekit.extras.tasks.shell import OutputLocation, ShellTask
Expand Down Expand Up @@ -43,10 +45,10 @@ def test_input_substitution_primitive():
t = ShellTask(
name="test",
script="""
set -ex
cat {{ .inputs.f }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
""",
set -ex
cat {f}
echo "Hello World {y} on {j}"
""",
inputs=kwtypes(f=str, y=int, j=datetime.datetime),
)

Expand All @@ -60,30 +62,28 @@ def test_input_substitution_files():
t = ShellTask(
name="test",
script="""
cat {{ .inputs.f }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
""",
cat {f}
echo "Hello World {y} on {j}"
""",
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
)

assert t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) is None


def test_input_output_substitution_files():
s = """
cat {{ .inputs.f }} > {{ .outputs.y }}
"""
script = "cat {f} > {y}"
t = ShellTask(
name="test",
debug=True,
script=s,
script=script,
inputs=kwtypes(f=CSVFile),
output_locs=[
OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.mod"),
OutputLocation(var="y", var_type=FlyteFile, location="{f}.mod"),
],
)

assert t.script == s
assert t.script == script

contents = "1,2,3,4\n"
with tempfile.TemporaryDirectory() as tmp:
Expand All @@ -100,70 +100,104 @@ def test_input_output_substitution_files():


def test_input_single_output_substitution_files():
s = """
cat {{ .inputs.f }} >> {{ .outputs.y }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}"
"""
script = """
cat {f} >> {z}
echo "Hello World {y} on {j}"
"""
t = ShellTask(
name="test",
debug=True,
script=s,
script=script,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc")],
output_locs=[OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc")],
)

assert t.script == s
assert t.script == script
y = t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))
assert y.path[-4:] == ".pyc"


def test_input_output_extra_var_in_template():
@pytest.mark.parametrize(
"script",
[
(
"""
cat {missing} >> {z}
echo "Hello World {y} on {j} - output {x}"
"""
),
(
"""
cat {f} {missing} >> {z}
echo "Hello World {y} on {j} - output {x}"
"""
),
],
)
def test_input_output_extra_and_missing_variables(script):
t = ShellTask(
name="test",
debug=True,
script="""
cat {{ .inputs.f }} {{ .inputs.missing }} >> {{ .outputs.y }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
""",
script=script,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
OutputLocation(var="x", var_type=FlyteDirectory, location="{y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc"),
],
)

with pytest.raises(ValueError):
with pytest.raises(ValueError, match="missing"):
t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


def test_input_output_extra_input():
def test_cannot_reuse_variables_for_both_inputs_and_outputs():
t = ShellTask(
name="test",
debug=True,
script="""
cat {{ .inputs.missing }} >> {{ .outputs.y }}
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
cat {f} >> {y}
echo "Hello World {y} on {j}"
""",
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
OutputLocation(var="y", var_type=FlyteFile, location="{f}.pyc"),
],
)

with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Variables {'y'} in Query"):
t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0))


def test_can_use_complex_types_for_inputs_to_f_string_template():
    """An attribute of a dataclass input can be referenced from the script and output location templates."""

    @dataclass_json
    @dataclass
    class InputArgs:
        in_file: CSVFile

    # The output location is derived from the nested input attribute.
    output_location = OutputLocation(var="x", var_type=FlyteFile, location="{input_args.in_file}.tmp")
    task = ShellTask(
        name="test",
        debug=True,
        script="""cat {input_args.in_file} >> {input_args.in_file}.tmp""",
        inputs=kwtypes(input_args=InputArgs),
        output_locs=[output_location],
    )

    result = task(input_args=InputArgs(FlyteFile(path=test_csv)))
    assert result.path[-4:] == ".tmp"


def test_shell_script():
t = ShellTask(
name="test2",
debug=True,
script_file=script_sh,
inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
output_locs=[
OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"),
OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"),
OutputLocation(var="x", var_type=FlyteDirectory, location="{y}"),
OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc"),
],
)

Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/extras/tasks/testdata/script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

set -ex

cat "{{ .inputs.f }}" >> "{{ .outputs.y }}"
echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}"
cat "{f}" >> "{z}"
echo "Hello World {y} on {j} - output {x}"

0 comments on commit f95993a

Please sign in to comment.