diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 6e8dbcc21b5..63bae594edf 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -1,7 +1,8 @@ +import collections import datetime import logging import os -import re +import string import subprocess import typing from dataclasses import dataclass @@ -30,46 +31,6 @@ class OutputLocation: location: typing.Union[os.PathLike, str] -def _stringify(v: typing.Any) -> str: - """ - Special cased return for the given value. Given the type returns the string version for the type. - Handles FlyteFile and FlyteDirectory specially. Downloads and returns the downloaded filepath - """ - if isinstance(v, FlyteFile): - v.download() - return v.path - if isinstance(v, FlyteDirectory): - v.download() - return v.path - if isinstance(v, datetime.datetime): - return v.isoformat() - return str(v) - - -def _interpolate(tmpl: str, regex: re.Pattern, validate_all_match: bool = True, **kwargs) -> str: - """ - Substitutes all templates that match the supplied regex - with the given inputs and returns the substituted string. The result is non destructive towards the given string. - """ - modified = tmpl - matched = set() - for match in regex.finditer(tmpl): - expr = match.groups()[0] - var = match.groups()[1] - if var not in kwargs: - raise ValueError(f"Variable {var} in Query (part of {expr}) not found in inputs {kwargs.keys()}") - matched.add(var) - val = kwargs[var] - # str conversion should be deliberate, with right conversion for each type - modified = modified.replace(expr, _stringify(val)) - - if validate_all_match: - if len(matched) < len(kwargs.keys()): - diff = set(kwargs.keys()).difference(matched) - raise ValueError(f"Extra Inputs have no matches in script template - missing {diff}") - return modified - - def _dummy_task_func(): """ A Fake function to satisfy the inner PythonTask requirements @@ -80,12 +41,51 @@ def _dummy_task_func(): T = typing.TypeVar("T") +class _PythonFStringInterpolizer: + """A class for interpolating scripts that use python string.format syntax""" + + class _Formatter(string.Formatter): + def format_field(self, value, format_spec): + """ + Special cased return for the given value. Given the type returns the string version for + the type. Handles FlyteFile and FlyteDirectory specially. + Downloads and returns the downloaded filepath. + """ + if isinstance(value, FlyteFile): + value.download() + return value.path + if isinstance(value, FlyteDirectory): + value.download() + return value.path + if isinstance(value, datetime.datetime): + return value.isoformat() + return super().format_field(value, format_spec) + + def interpolate( + self, + tmpl: str, + inputs: typing.Optional[typing.Dict[str, str]] = None, + outputs: typing.Optional[typing.Dict[str, str]] = None, + ) -> str: + """ + Interpolate python formatted string templates with variables from the input and output + argument dicts. The result is non destructive towards the given template string. + """ + inputs = inputs or {} + outputs = outputs or {} + reused_vars = inputs.keys() & outputs.keys() + if reused_vars: + raise ValueError(f"Variables {reused_vars} in Query cannot be shared between inputs and outputs.") + consolidated_args = collections.ChainMap(inputs, outputs) + try: + return self._Formatter().format(tmpl, **consolidated_args) + except KeyError as e: + raise ValueError(f"Variable {e} in Query not found in inputs {consolidated_args.keys()}") + + class ShellTask(PythonInstanceTask[T]): """ """ - _INPUT_REGEX = re.compile(r"({{\s*.inputs.(\w+)\s*}})", re.IGNORECASE) - _OUTPUT_REGEX = re.compile(r"({{\s*.outputs.(\w+)\s*}})", re.IGNORECASE) - def __init__( self, name: str, @@ -136,6 +136,7 @@ def __init__( self._script_file = script_file self._debug = debug self._output_locs = output_locs if output_locs else [] + self._interpolizer = _PythonFStringInterpolizer() outputs = self._validate_output_locs() super().__init__( name, @@ -184,12 +185,9 @@ def execute(self, **kwargs) -> typing.Any: outputs: typing.Dict[str, str] = {} if self._output_locs: for v in self._output_locs: - outputs[v.var] = _interpolate(v.location, self._INPUT_REGEX, validate_all_match=False, **kwargs) + outputs[v.var] = self._interpolizer.interpolate(v.location, inputs=kwargs) - gen_script = _interpolate(self._script, self._INPUT_REGEX, **kwargs) - # For outputs it is not necessary that all outputs are used in the script, some are implicit outputs - # for example gcc main.c will generate a.out automatically - gen_script = _interpolate(gen_script, self._OUTPUT_REGEX, validate_all_match=False, **outputs) + gen_script = self._interpolizer.interpolate(self._script, inputs=kwargs, outputs=outputs) if self._debug: print("\n==============================================\n") print(gen_script) diff --git a/tests/flytekit/unit/extras/tasks/test_shell.py b/tests/flytekit/unit/extras/tasks/test_shell.py index 39f114a9c75..2b76f7ad7fe 100644 --- a/tests/flytekit/unit/extras/tasks/test_shell.py +++ b/tests/flytekit/unit/extras/tasks/test_shell.py @@ -1,9 +1,11 @@ import datetime import os import tempfile +from dataclasses import dataclass from subprocess import CalledProcessError import pytest +from dataclasses_json import dataclass_json from flytekit import kwtypes from flytekit.extras.tasks.shell import OutputLocation, ShellTask @@ -43,10 +45,10 @@ def test_input_substitution_primitive(): t = ShellTask( name="test", script=""" - set -ex - cat {{ .inputs.f }} - echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}" - """, + set -ex + cat {f} + echo "Hello World {y} on {j}" + """, inputs=kwtypes(f=str, y=int, j=datetime.datetime), ) @@ -60,9 +62,9 @@ def test_input_substitution_files(): t = ShellTask( name="test", script=""" - cat {{ .inputs.f }} - echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}" - """, + cat {f} + echo "Hello World {y} on {j}" + """, inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime), ) @@ -70,20 +72,18 @@ def test_input_substitution_files(): def test_input_output_substitution_files(): - s = """ - cat {{ .inputs.f }} > {{ .outputs.y }} - """ + script = "cat {f} > {y}" t = ShellTask( name="test", debug=True, - script=s, + script=script, inputs=kwtypes(f=CSVFile), output_locs=[ - OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.mod"), + OutputLocation(var="y", var_type=FlyteFile, location="{f}.mod"), ], ) - assert t.script == s + assert t.script == script contents = "1,2,3,4\n" with tempfile.TemporaryDirectory() as tmp: @@ -100,61 +100,95 @@ def test_input_output_substitution_files(): def test_input_single_output_substitution_files(): - s = """ - cat {{ .inputs.f }} >> {{ .outputs.y }} - echo "Hello World {{ .inputs.y }} on {{ .inputs.j }}" - """ + script = """ + cat {f} >> {z} + echo "Hello World {y} on {j}" + """ t = ShellTask( name="test", debug=True, - script=s, + script=script, inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime), - output_locs=[OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc")], + output_locs=[OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc")], ) - assert t.script == s + assert t.script == script y = t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) assert y.path[-4:] == ".pyc" -def test_input_output_extra_var_in_template(): +@pytest.mark.parametrize( + "script", + [ + ( + """ + cat {missing} >> {z} + echo "Hello World {y} on {j} - output {x}" + """ + ), + ( + """ + cat {f} {missing} >> {z} + echo "Hello World {y} on {j} - output {x}" + """ + ), + ], +) +def test_input_output_extra_and_missing_variables(script): t = ShellTask( name="test", debug=True, - script=""" - cat {{ .inputs.f }} {{ .inputs.missing }} >> {{ .outputs.y }} - echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}" - """, + script=script, inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime), output_locs=[ - OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"), - OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"), + OutputLocation(var="x", var_type=FlyteDirectory, location="{y}"), + OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc"), ], ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="missing"): t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) -def test_input_output_extra_input(): +def test_cannot_reuse_variables_for_both_inputs_and_outputs(): t = ShellTask( name="test", debug=True, script=""" - cat {{ .inputs.missing }} >> {{ .outputs.y }} - echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}" + cat {f} >> {y} + echo "Hello World {y} on {j}" """, inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime), output_locs=[ - OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"), - OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"), + OutputLocation(var="y", var_type=FlyteFile, location="{f}.pyc"), ], ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Variables {'y'} in Query"): t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) +def test_can_use_complex_types_for_inputs_to_f_string_template(): + @dataclass_json + @dataclass + class InputArgs: + in_file: CSVFile + + t = ShellTask( + name="test", + debug=True, + script="""cat {input_args.in_file} >> {input_args.in_file}.tmp""", + inputs=kwtypes(input_args=InputArgs), + output_locs=[ + OutputLocation(var="x", var_type=FlyteFile, location="{input_args.in_file}.tmp"), + ], + ) + + input_args = InputArgs(FlyteFile(path=test_csv)) + x = t(input_args=input_args) + assert x.path[-4:] == ".tmp" + + def test_shell_script(): t = ShellTask( name="test2", @@ -162,8 +196,8 @@ def test_shell_script(): script_file=script_sh, inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime), output_locs=[ - OutputLocation(var="x", var_type=FlyteDirectory, location="{{ .inputs.y }}"), - OutputLocation(var="y", var_type=FlyteFile, location="{{ .inputs.f }}.pyc"), + OutputLocation(var="x", var_type=FlyteDirectory, location="{y}"), + OutputLocation(var="z", var_type=FlyteFile, location="{f}.pyc"), ], ) diff --git a/tests/flytekit/unit/extras/tasks/testdata/script.sh b/tests/flytekit/unit/extras/tasks/testdata/script.sh index 22012ec3ae0..1deb4c474a0 100644 --- a/tests/flytekit/unit/extras/tasks/testdata/script.sh +++ b/tests/flytekit/unit/extras/tasks/testdata/script.sh @@ -2,5 +2,5 @@ set -ex -cat "{{ .inputs.f }}" >> "{{ .outputs.y }}" -echo "Hello World {{ .inputs.y }} on {{ .inputs.j }} - output {{.outputs.x}}" +cat "{f}" >> "{z}" +echo "Hello World {y} on {j} - output {x}"