From 5592b9fb09a8f67cd0086ef1c623aa1cb1a2030f Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 19 Sep 2019 16:09:53 -0700 Subject: [PATCH] Optimize Command line template parsing (#15) * Template is more optimal to lazily serialize input parameters to command line * Fixed the usage * return an error in case of group match but no literal --- go/tasks/v1/flytek8s/container_helper.go | 2 +- go/tasks/v1/k8splugins/sidecar.go | 4 +- go/tasks/v1/k8splugins/spark.go | 4 +- go/tasks/v1/utils/template.go | 84 +++++++---------- go/tasks/v1/utils/template_test.go | 111 ++++++++++++++--------- 5 files changed, 108 insertions(+), 97 deletions(-) diff --git a/go/tasks/v1/flytek8s/container_helper.go b/go/tasks/v1/flytek8s/container_helper.go index 5c0f6dbf7..8f0d6597a 100755 --- a/go/tasks/v1/flytek8s/container_helper.go +++ b/go/tasks/v1/flytek8s/container_helper.go @@ -78,7 +78,7 @@ func ToK8sContainer(ctx context.Context, taskCtx types.TaskContext, taskContaine cmdLineArgs := utils.CommandLineTemplateArgs{ Input: inputFile.String(), OutputPrefix: taskCtx.GetDataDir().String(), - Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + Inputs: inputs, } modifiedCommand, err := utils.ReplaceTemplateCommandArgs(ctx, taskContainer.GetCommand(), cmdLineArgs) diff --git a/go/tasks/v1/k8splugins/sidecar.go b/go/tasks/v1/k8splugins/sidecar.go index 3f3d7b265..bff90a465 100755 --- a/go/tasks/v1/k8splugins/sidecar.go +++ b/go/tasks/v1/k8splugins/sidecar.go @@ -45,7 +45,7 @@ func validateAndFinalizeContainers( utils.CommandLineTemplateArgs{ Input: taskCtx.GetInputsFile().String(), OutputPrefix: taskCtx.GetDataDir().String(), - Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + Inputs: inputs, }) if err != nil { @@ -58,7 +58,7 @@ func validateAndFinalizeContainers( utils.CommandLineTemplateArgs{ Input: taskCtx.GetInputsFile().String(), OutputPrefix: taskCtx.GetDataDir().String(), - Inputs: utils.LiteralMapToTemplateArgs(ctx, inputs), + Inputs: inputs, }) if err != nil { diff --git a/go/tasks/v1/k8splugins/spark.go b/go/tasks/v1/k8splugins/spark.go index 17672cb40..5f9ad35fd 100755 --- a/go/tasks/v1/k8splugins/spark.go +++ b/go/tasks/v1/k8splugins/spark.go @@ -113,12 +113,12 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx types.Tas }, } - modifiedArgs, err := utils.ReplaceTemplateCommandArgs(context.TODO(), + modifiedArgs, err := utils.ReplaceTemplateCommandArgs(ctx, task.GetContainer().GetArgs(), utils.CommandLineTemplateArgs{ Input: taskCtx.GetInputsFile().String(), OutputPrefix: taskCtx.GetDataDir().String(), - Inputs: utils.LiteralMapToTemplateArgs(context.TODO(), inputs), + Inputs: inputs, }) if err != nil { diff --git a/go/tasks/v1/utils/template.go b/go/tasks/v1/utils/template.go index 1d88ec6f0..84a952365 100755 --- a/go/tasks/v1/utils/template.go +++ b/go/tasks/v1/utils/template.go @@ -9,7 +9,7 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flytestdlib/logger" + "github.com/pkg/errors" ) var inputFileRegex = regexp.MustCompile(`(?i){{\s*[\.$]Input\s*}}`) @@ -18,9 +18,9 @@ var inputVarRegex = regexp.MustCompile(`(?i){{\s*[\.$]Inputs\.(?P[^} // Contains arguments passed down to command line templates. type CommandLineTemplateArgs struct { - Input string `json:"input"` - OutputPrefix string `json:"output"` - Inputs map[string]string `json:"inputs"` + Input string `json:"input"` + OutputPrefix string `json:"output"` + Inputs *core.LiteralMap `json:"inputs"` } // Evaluates templates in each command with the equivalent value from passed args. Templates are case-insensitive @@ -48,7 +48,7 @@ func ReplaceTemplateCommandArgs(ctx context.Context, command []string, args Comm return res, nil } -func replaceTemplateCommandArgs(_ context.Context, commandTemplate string, args *CommandLineTemplateArgs) (string, error) { +func replaceTemplateCommandArgs(ctx context.Context, commandTemplate string, args *CommandLineTemplateArgs) (string, error) { val := inputFileRegex.ReplaceAllString(commandTemplate, args.Input) val = outputRegex.ReplaceAllString(val, args.OutputPrefix) groupMatches := inputVarRegex.FindAllStringSubmatchIndex(val, -1) @@ -64,89 +64,75 @@ func replaceTemplateCommandArgs(_ context.Context, commandTemplate string, args inputStartIdx := groupMatches[0][2] inputEndIdx := groupMatches[0][3] inputName := val[inputStartIdx:inputEndIdx] - inputVal, exists := args.Inputs[inputName] + + if args.Inputs == nil || args.Inputs.Literals == nil { + return val, fmt.Errorf("no inputs provided, cannot bind input name [%s]", inputName) + } + inputVal, exists := args.Inputs.Literals[inputName] if !exists { return val, fmt.Errorf("requested input is not found [%v] while processing template [%v]", inputName, commandTemplate) } + v, err := serializeLiteral(ctx, inputVal) + if err != nil { + return val, errors.Wrapf(err, "failed to bind a value to inputName [%s]", inputName) + } if endIdx >= len(val) { - return val[:startIdx] + inputVal, nil + return val[:startIdx] + v, nil } - return val[:startIdx] + inputVal + val[endIdx:], nil - } -} - -// Converts a literal map to a go map that can be used in templates. It drops literals that don't have a defined way to -// be safely serialized into a string. -func LiteralMapToTemplateArgs(ctx context.Context, m *core.LiteralMap) map[string]string { - if m == nil { - return map[string]string{} - } - - res := make(map[string]string, len(m.Literals)) - - for key, val := range m.Literals { - serialized, ok := serializeLiteral(ctx, val) - if ok { - res[key] = serialized - } + return val[:startIdx] + v + val[endIdx:], nil } - - return res } -func serializePrimitive(ctx context.Context, p *core.Primitive) (string, bool) { +func serializePrimitive(p *core.Primitive) (string, error) { switch o := p.Value.(type) { case *core.Primitive_Integer: - return fmt.Sprintf("%v", o.Integer), true + return fmt.Sprintf("%v", o.Integer), nil case *core.Primitive_Boolean: - return fmt.Sprintf("%v", o.Boolean), true + return fmt.Sprintf("%v", o.Boolean), nil case *core.Primitive_Datetime: - return ptypes.TimestampString(o.Datetime), true + return ptypes.TimestampString(o.Datetime), nil case *core.Primitive_Duration: - return o.Duration.String(), true + return o.Duration.String(), nil case *core.Primitive_FloatValue: - return fmt.Sprintf("%v", o.FloatValue), true + return fmt.Sprintf("%v", o.FloatValue), nil case *core.Primitive_StringValue: - return o.StringValue, true + return o.StringValue, nil default: - logger.Warnf(ctx, "Received an unexpected primitive type [%v]", reflect.TypeOf(p.Value)) - return "", false + return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(p.Value)) } } -func serializeLiteralScalar(ctx context.Context, l *core.Scalar) (string, bool) { +func serializeLiteralScalar(l *core.Scalar) (string, error) { switch o := l.Value.(type) { case *core.Scalar_Primitive: - return serializePrimitive(ctx, o.Primitive) + return serializePrimitive(o.Primitive) case *core.Scalar_Blob: - return o.Blob.Uri, true + return o.Blob.Uri, nil default: - logger.Warnf(ctx, "Received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) - return "", false + return "", fmt.Errorf("received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) } } -func serializeLiteral(ctx context.Context, l *core.Literal) (string, bool) { +func serializeLiteral(ctx context.Context, l *core.Literal) (string, error) { switch o := l.Value.(type) { case *core.Literal_Collection: res := make([]string, 0, len(o.Collection.Literals)) for _, sub := range o.Collection.Literals { - s, ok := serializeLiteral(ctx, sub) - if !ok { - return "", false + s, err := serializeLiteral(ctx, sub) + if err != nil { + return "", err } res = append(res, s) } - return fmt.Sprintf("[%v]", strings.Join(res, ",")), true + return fmt.Sprintf("[%v]", strings.Join(res, ",")), nil case *core.Literal_Scalar: - return serializeLiteralScalar(ctx, o.Scalar) + return serializeLiteralScalar(o.Scalar) default: - logger.Warnf(ctx, "Received an unexpected primitive type [%v]", reflect.TypeOf(l.Value)) - return "", false + return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(l.Value)) } } diff --git a/go/tasks/v1/utils/template_test.go b/go/tasks/v1/utils/template_test.go index 5060927db..b8c2369bc 100755 --- a/go/tasks/v1/utils/template_test.go +++ b/go/tasks/v1/utils/template_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/lyft/flyteidl/clients/go/coreutils" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/stretchr/testify/assert" ) @@ -27,8 +28,10 @@ func BenchmarkReplacements(b *testing.B) { cmdTemplate := `abc {{ index .Inputs "x" }}` cmdArgs := CommandLineTemplateArgs{ Input: "inputfile.pb", - Inputs: map[string]string{ - "x": "1", + Inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "x": coreutils.MustMakePrimitiveLiteral(1), + }, }, } @@ -180,8 +183,16 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { }, CommandLineTemplateArgs{ Input: "input/blah", OutputPrefix: "output/blah", - Inputs: map[string]string{ - "arr": "[a,b]", + Inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "arr": { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{coreutils.MustMakeLiteral("a"), coreutils.MustMakeLiteral("b")}, + }, + }, + }, + }, }}) assert.NoError(t, err) assert.Equal(t, []string{ @@ -191,49 +202,63 @@ func TestReplaceTemplateCommandArgs(t *testing.T) { "output/blah", }, actual) }) -} - -func TestLiteralMapToTemplateArgs(t *testing.T) { - t.Run("Scalars", func(t *testing.T) { - expected := map[string]string{ - "str": "blah", - "int": "5", - "date": "1900-01-01T01:01:01.000000001Z", - } - - dd := time.Date(1900, 1, 1, 1, 1, 1, 1, time.UTC) - lit := coreutils.MustMakeLiteral(map[string]interface{}{ - "str": "blah", - "int": 5, - "date": dd, - }) - - actual := LiteralMapToTemplateArgs(context.TODO(), lit.GetMap()) - assert.Equal(t, expected, actual) + t.Run("Date", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.date }}`, + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{ + Input: "input/blah", + OutputPrefix: "output/blah", + Inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "date": coreutils.MustMakeLiteral(time.Date(1900, 01, 01, 01, 01, 01, 000000001, time.UTC)), + }, + }}) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + "--someArg 1900-01-01T01:01:01.000000001Z", + "output/blah", + }, actual) }) - t.Run("1d array", func(t *testing.T) { - expected := map[string]string{ - "arr": "[a,b]", - } - - actual := LiteralMapToTemplateArgs(context.TODO(), coreutils.MustMakeLiteral(map[string]interface{}{ - "arr": []interface{}{"a", "b"}, - }).GetMap()) - - assert.Equal(t, expected, actual) + t.Run("2d Array arg", func(t *testing.T) { + actual, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.arr }}`, + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{ + Input: "input/blah", + OutputPrefix: "output/blah", + Inputs: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), + }, + }}) + assert.NoError(t, err) + assert.Equal(t, []string{ + "hello", + "world", + "--someArg [[a,b],[1,2]]", + "output/blah", + }, actual) }) - t.Run("2d array", func(t *testing.T) { - expected := map[string]string{ - "arr": "[[a,b],[1,2]]", - } - - actual := LiteralMapToTemplateArgs(context.TODO(), coreutils.MustMakeLiteral(map[string]interface{}{ - "arr": []interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}, - }).GetMap()) - - assert.Equal(t, expected, actual) + t.Run("nil input", func(t *testing.T) { + _, err := ReplaceTemplateCommandArgs(context.TODO(), []string{ + "hello", + "world", + `--someArg {{ .Inputs.arr }}`, + "{{ .OutputPrefix }}", + }, CommandLineTemplateArgs{ + Input: "input/blah", + OutputPrefix: "output/blah", + Inputs: &core.LiteralMap{Literals: nil}}) + assert.Error(t, err) }) }