diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/error_collection.go b/flyteplugins/go/tasks/pluginmachinery/utils/error_collection.go new file mode 100644 index 000000000..f833b994c --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/utils/error_collection.go @@ -0,0 +1,19 @@ +package utils + +import ( + "fmt" + "strings" +) + +type ErrorCollection struct { + Errors []error +} + +func (e ErrorCollection) Error() string { + sb := strings.Builder{} + for idx, err := range e.Errors { + sb.WriteString(fmt.Sprintf("%v: %v\r\n", idx, err)) + } + + return sb.String() +} diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/error_collection_test.go b/flyteplugins/go/tasks/pluginmachinery/utils/error_collection_test.go new file mode 100644 index 000000000..dd5318251 --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/utils/error_collection_test.go @@ -0,0 +1,22 @@ +package utils + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorCollection(t *testing.T) { + ec := ErrorCollection{} + + assert.Empty(t, ec.Error()) + + ec.Errors = append(ec.Errors, fmt.Errorf("error1")) + assert.NotEmpty(t, ec.Error()) + + ec.Errors = append(ec.Errors, fmt.Errorf("error2")) + assert.NotEmpty(t, ec.Error()) + + assert.Equal(t, "0: error1\r\n1: error2\r\n", ec.Error()) +} diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/template.go b/flyteplugins/go/tasks/pluginmachinery/utils/template.go index 5f1a777d1..087ced2f6 100755 --- a/flyteplugins/go/tasks/pluginmachinery/utils/template.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/template.go @@ -51,47 +51,49 @@ func ReplaceTemplateCommandArgs(ctx context.Context, command []string, in io.Inp return res, nil } +func transformVarNameToStringVal(ctx context.Context, varName string, inputs *core.LiteralMap) (string, error) { + inputVal, exists := inputs.Literals[varName] + if !exists { + return "", fmt.Errorf("requested input is not found [%s]", varName) + } + + v, err := serializeLiteral(ctx, inputVal) + if err != nil { + return "", errors.Wrapf(err, "failed to bind a value to inputName [%s]", varName) + } + return v, nil +} + func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, in io.InputReader, out io.OutputFilePaths) (string, error) { val := inputFileRegex.ReplaceAllString(commandTemplate, in.GetInputPath().String()) val = outputRegex.ReplaceAllString(val, out.GetOutputPrefixPath().String()) val = inputPrefixRegex.ReplaceAllString(val, in.GetInputPrefixPath().String()) - groupMatches := inputVarRegex.FindAllStringSubmatchIndex(val, -1) - if len(groupMatches) == 0 { - return val, nil - } else if len(groupMatches) > 1 { - return val, fmt.Errorf("only one level of inputs nesting is supported. Syntax in [%v] is invalid", commandTemplate) - } else if len(groupMatches[0]) > 4 { - return val, fmt.Errorf("longer submatches not supported. Syntax in [%v] is invalid", commandTemplate) - } - startIdx := groupMatches[0][0] - endIdx := groupMatches[0][1] - inputStartIdx := groupMatches[0][2] - inputEndIdx := groupMatches[0][3] - inputName := val[inputStartIdx:inputEndIdx] inputs, err := in.Get(ctx) if err != nil { - return val, errors.Wrapf(err, "unable to read inputs for [%s]", inputName) + return val, errors.Wrapf(err, "unable to read inputs") } if inputs == nil || inputs.Literals == nil { - return val, fmt.Errorf("no inputs provided, cannot bind input name [%s]", inputName) - } - inputVal, exists := inputs.Literals[inputName] - if !exists { - return val, fmt.Errorf("requested input is not found [%v] while processing template [%v]", - inputName, commandTemplate) + return val, nil } - v, err := serializeLiteral(ctx, inputVal) - if err != nil { - return val, errors.Wrapf(err, "failed to bind a value to inputName [%s]", inputName) - } - if endIdx >= len(val) { - return val[:startIdx] + v, nil - } + var errs ErrorCollection + val = inputVarRegex.ReplaceAllStringFunc(val, func(s string) string { + matches := inputVarRegex.FindAllStringSubmatch(s, 1) + varName := matches[0][1] + replaced, err := transformVarNameToStringVal(ctx, varName, inputs) + if err != nil { + errs.Errors = append(errs.Errors, errors.Wrapf(err, "input template [%s]", s)) + return "" + } + return replaced + }) - return val[:startIdx] + v + val[endIdx:], nil + if len(errs.Errors) > 0 { + return "", errs + } + return val, nil } func serializePrimitive(p *core.Primitive) (string, error) { diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/template_test.go b/flyteplugins/go/tasks/pluginmachinery/utils/template_test.go index 6b93cb545..07ee1e3fa 100755 --- a/flyteplugins/go/tasks/pluginmachinery/utils/template_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/template_test.go @@ -260,12 +260,82 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { t.Run("nil input", func(t *testing.T) { in := dummyInputReader{inputs: &core.LiteralMap{}} - _, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.arr }}`, + "{{ .OutputPrefix }}", + }, in, out) + assert.NoError(t, err) + assert.Equal(t, []string{ "hello", "world", `--someArg {{ .Inputs.arr }}`, + "output/blah", + }, actual) + }) + + t.Run("multi-input", func(t *testing.T) { + in := dummyInputReader{inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "ds": coreutils.MustMakeLiteral(time.Date(1900, 01, 01, 01, 01, 01, 000000001, time.UTC)), + "table": coreutils.MustMakeLiteral("my_table"), + "hr": coreutils.MustMakeLiteral("hr"), + "min": coreutils.MustMakeLiteral(15), + }, + }} + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + `SELECT + COUNT(*) as total_count + FROM + hive.events.{{ .Inputs.table }} + WHERE + ds = '{{ .Inputs.ds }}' AND hr = '{{ .Inputs.hr }}' AND min = {{ .Inputs.min }} + `}, in, out) + assert.NoError(t, err) + assert.Equal(t, []string{ + `SELECT + COUNT(*) as total_count + FROM + hive.events.my_table + WHERE + ds = '1900-01-01T01:01:01.000000001Z' AND hr = 'hr' AND min = 15 + `}, actual) + }) + + t.Run("missing input", func(t *testing.T) { + in := dummyInputReader{inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), + }, + }} + _, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.blah }}`, "{{ .OutputPrefix }}", }, in, out) assert.Error(t, err) }) + + t.Run("bad template", func(t *testing.T) { + in := dummyInputReader{inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), + }, + }} + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.blah blah }}`, + "{{ .OutputPrefix }}", + }, in, out) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + `--someArg {{ .Inputs.blah blah }}`, + "output/blah", + }, actual) + }) } diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/transformers_test.go b/flyteplugins/go/tasks/pluginmachinery/utils/transformers_test.go index f16732901..17047211e 100755 --- a/flyteplugins/go/tasks/pluginmachinery/utils/transformers_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/transformers_test.go @@ -16,3 +16,11 @@ func TestContains(t *testing.T) { assert.False(t, Contains(nil, "b")) } + +func TestCopyMap(t *testing.T) { + assert.Nil(t, CopyMap(nil)) + m := map[string]string{ + "l": "v", + } + assert.Equal(t, m, CopyMap(m)) +} diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/launcher_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/launcher_test.go index b35fa24d0..159c2759a 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/launcher_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/launcher_test.go @@ -82,6 +82,7 @@ func TestLaunchSubTasks(t *testing.T) { ir := &mocks3.InputReader{} ir.OnGetInputPrefixPath().Return("/prefix/") ir.OnGetInputPath().Return("/prefix/inputs.pb") + ir.OnGetMatch(mock.Anything).Return(nil, nil) tCtx := &mocks.TaskExecutionContext{} tCtx.OnTaskReader().Return(tr) diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go index 6894f64d0..10fe7f92b 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/transformer_test.go @@ -168,6 +168,7 @@ func TestArrayJobToBatchInput(t *testing.T) { ir := &mocks2.InputReader{} ir.OnGetInputPath().Return("inputs.pb") ir.OnGetInputPrefixPath().Return("/inputs/prefix") + ir.OnGetMatch(mock.Anything).Return(nil, nil) or := &mocks2.OutputWriter{} or.OnGetOutputPrefixPath().Return("/path/output") diff --git a/flyteplugins/tests/end_to_end.go b/flyteplugins/tests/end_to_end.go index e29b8fee5..72589558e 100644 --- a/flyteplugins/tests/end_to_end.go +++ b/flyteplugins/tests/end_to_end.go @@ -83,6 +83,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i inputReader := &ioMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return(basePrefix) inputReader.OnGetInputPath().Return(basePrefix + "/inputs.pb") + inputReader.OnGetMatch(mock.Anything).Return(inputs, nil) outputWriter := &ioMocks.OutputWriter{} outputWriter.OnGetRawOutputPrefix().Return("/sandbox/")