From a2988ba35c8c5c11616bdccf7018b9255fa8e6e7 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 31 Oct 2024 17:29:59 -0700 Subject: [PATCH] Improve literal type string representation handling (#5932) Signed-off-by: Kevin Su --- .../pkg/compiler/common/pretty_print.go | 23 ++++++++++++ .../pkg/compiler/common/pretty_print_test.go | 36 +++++++++++++++++++ .../pkg/compiler/transformers/k8s/inputs.go | 2 +- .../pkg/compiler/validators/bindings.go | 6 ++-- .../pkg/compiler/validators/condition.go | 2 +- .../pkg/compiler/validators/vars.go | 2 +- 6 files changed, 65 insertions(+), 6 deletions(-) create mode 100644 flytepropeller/pkg/compiler/common/pretty_print.go create mode 100644 flytepropeller/pkg/compiler/common/pretty_print_test.go diff --git a/flytepropeller/pkg/compiler/common/pretty_print.go b/flytepropeller/pkg/compiler/common/pretty_print.go new file mode 100644 index 0000000000..61df408a4e --- /dev/null +++ b/flytepropeller/pkg/compiler/common/pretty_print.go @@ -0,0 +1,23 @@ +package common + +import ( + "fmt" + "strings" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" +) + +func LiteralTypeToStr(lt *core.LiteralType) string { + if lt == nil { + return "None" + } + if lt.GetSimple() == core.SimpleType_STRUCT { + var structure string + for k, v := range lt.GetStructure().GetDataclassType() { + structure += fmt.Sprintf("dataclass_type:{key:%v value:{%v}, ", k, LiteralTypeToStr(v)) + } + structure = strings.TrimSuffix(structure, ", ") + return fmt.Sprintf("simple: STRUCT structure{%v}", structure) + } + return lt.String() +} diff --git a/flytepropeller/pkg/compiler/common/pretty_print_test.go b/flytepropeller/pkg/compiler/common/pretty_print_test.go new file mode 100644 index 0000000000..2d875af5dd --- /dev/null +++ b/flytepropeller/pkg/compiler/common/pretty_print_test.go @@ -0,0 +1,36 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" +) + +func TestLiteralTypeToStr(t *testing.T) { + dataclassType := &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}, + Structure: &core.TypeStructure{ + DataclassType: map[string]*core.LiteralType{ + "a": { + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + }, + }, + }, + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{ + "key": {Kind: &structpb.Value_StringValue{StringValue: "a"}}, + }}, + } + assert.Equal(t, LiteralTypeToStr(nil), "None") + assert.Equal(t, LiteralTypeToStr(dataclassType), "simple: STRUCT structure{dataclass_type:{key:a value:{simple:INTEGER}}") + assert.NotEqual(t, LiteralTypeToStr(dataclassType), dataclassType.String()) + + // Test for SimpleType + simpleType := &core.LiteralType{ + Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}, + } + assert.Equal(t, LiteralTypeToStr(simpleType), "simple:INTEGER") + assert.Equal(t, LiteralTypeToStr(simpleType), simpleType.String()) +} diff --git a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go index 26f50d4ddd..2b94570c20 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go @@ -42,7 +42,7 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor continue } if !validators.AreTypesCastable(inputType, v.Type) { - errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String())) + errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, common.LiteralTypeToStr(v.Type), common.LiteralTypeToStr(inputType))) continue } diff --git a/flytepropeller/pkg/compiler/validators/bindings.go b/flytepropeller/pkg/compiler/validators/bindings.go index 53535ba260..b69dda529f 100644 --- a/flytepropeller/pkg/compiler/validators/bindings.go +++ b/flytepropeller/pkg/compiler/validators/bindings.go @@ -131,7 +131,7 @@ func validateBinding(w c.WorkflowBuilder, node c.Node, nodeParam string, binding // If the variable has an index. We expect param to be a collection. if v.Index != nil { if cType := param.GetType().GetCollectionType(); cType == nil { - errs.Collect(errors.NewMismatchingVariablesErr(nodeID, outputVar, param.Type.String(), inputVar, expectedType.String())) + errs.Collect(errors.NewMismatchingVariablesErr(nodeID, outputVar, c.LiteralTypeToStr(param.Type), inputVar, c.LiteralTypeToStr(expectedType))) } else { sourceType = cType } @@ -164,7 +164,7 @@ func validateBinding(w c.WorkflowBuilder, node c.Node, nodeParam string, binding return param.GetType(), []c.NodeID{val.Promise.NodeId}, true } - errs.Collect(errors.NewMismatchingVariablesErr(node.GetId(), outputVar, sourceType.String(), inputVar, expectedType.String())) + errs.Collect(errors.NewMismatchingVariablesErr(node.GetId(), outputVar, c.LiteralTypeToStr(sourceType), inputVar, c.LiteralTypeToStr(expectedType))) return nil, nil, !errs.HasErrors() } } @@ -180,7 +180,7 @@ func validateBinding(w c.WorkflowBuilder, node c.Node, nodeParam string, binding if literalType == nil { errs.Collect(errors.NewUnrecognizedValueErr(nodeID, reflect.TypeOf(val.Scalar.GetValue()).String())) } else if validateParamTypes && !AreTypesCastable(literalType, expectedType) { - errs.Collect(errors.NewMismatchingTypesErr(nodeID, nodeParam, literalType.String(), expectedType.String())) + errs.Collect(errors.NewMismatchingTypesErr(nodeID, nodeParam, c.LiteralTypeToStr(literalType), c.LiteralTypeToStr(expectedType))) } if expectedType.GetEnumType() != nil { diff --git a/flytepropeller/pkg/compiler/validators/condition.go b/flytepropeller/pkg/compiler/validators/condition.go index 8e202b6423..70b72cde8a 100644 --- a/flytepropeller/pkg/compiler/validators/condition.go +++ b/flytepropeller/pkg/compiler/validators/condition.go @@ -44,7 +44,7 @@ func ValidateBooleanExpression(w c.WorkflowBuilder, node c.NodeBuilder, expr *fl if op1Valid && op2Valid && op1Type != nil && op2Type != nil { if op1Type.String() != op2Type.String() { errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue", - op1Type.String(), op2Type.String())) + c.LiteralTypeToStr(op1Type), c.LiteralTypeToStr(op2Type))) } } } else if expr.GetConjunction() != nil { diff --git a/flytepropeller/pkg/compiler/validators/vars.go b/flytepropeller/pkg/compiler/validators/vars.go index 53ca67e4ee..e114dc4fc0 100644 --- a/flytepropeller/pkg/compiler/validators/vars.go +++ b/flytepropeller/pkg/compiler/validators/vars.go @@ -40,7 +40,7 @@ func validateInputVar(n c.NodeBuilder, paramName string, requireParamType bool, func validateVarType(nodeID c.NodeID, paramName string, param *flyte.Variable, expectedType *flyte.LiteralType, errs errors.CompileErrors) (ok bool) { if param.GetType().String() != expectedType.String() { - errs.Collect(errors.NewMismatchingTypesErr(nodeID, paramName, param.GetType().String(), expectedType.String())) + errs.Collect(errors.NewMismatchingTypesErr(nodeID, paramName, c.LiteralTypeToStr(param.GetType()), c.LiteralTypeToStr(expectedType))) } return !errs.HasErrors()