Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve literal type string representation handling #5932

Merged
merged 16 commits into from
Nov 1, 2024
23 changes: 23 additions & 0 deletions flytepropeller/pkg/compiler/common/pretty_print.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package common

import (
"fmt"
"strings"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

func LiteralTypeToStr(lt *core.LiteralType) string {
if lt == nil {
return "None"
}
if lt.GetSimple() == core.SimpleType_STRUCT {
var structure string
for k, v := range lt.GetStructure().GetDataclassType() {
structure += fmt.Sprintf("dataclass_type:{key:%v value:{%v}, ", k, LiteralTypeToStr(v))
}
structure = strings.TrimSuffix(structure, ", ")
return fmt.Sprintf("simple: STRUCT structure{%v}", structure)
}
return lt.String()
}
37 changes: 37 additions & 0 deletions flytepropeller/pkg/compiler/common/pretty_print_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package common

import (
"testing"

"google.golang.org/protobuf/types/known/structpb"

"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

func TestLiteralTypeToStr(t *testing.T) {
dataclassType := &core.LiteralType{
Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT},
Structure: &core.TypeStructure{
DataclassType: map[string]*core.LiteralType{
"a": {
Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER},
},
},
},
Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{
"key": {Kind: &structpb.Value_StringValue{StringValue: "a"}},
}},
}
assert.Equal(t, LiteralTypeToStr(nil), "None")
assert.Equal(t, LiteralTypeToStr(dataclassType), "simple: STRUCT structure{dataclass_type:{key:a value:{simple:INTEGER}}")
assert.NotEqual(t, LiteralTypeToStr(dataclassType), dataclassType.String())

// Test for SimpleType
simpleType := &core.LiteralType{
Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER},
}
assert.Equal(t, LiteralTypeToStr(simpleType), "simple:INTEGER")
assert.Equal(t, LiteralTypeToStr(simpleType), simpleType.String())
}
2 changes: 1 addition & 1 deletion flytepropeller/pkg/compiler/transformers/k8s/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
continue
}
if !validators.AreTypesCastable(inputType, v.Type) {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String()))
errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, common.LiteralTypeToStr(v.Type), common.LiteralTypeToStr(inputType)))

Check warning on line 45 in flytepropeller/pkg/compiler/transformers/k8s/inputs.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/transformers/k8s/inputs.go#L45

Added line #L45 was not covered by tests
continue
}

Expand Down
6 changes: 3 additions & 3 deletions flytepropeller/pkg/compiler/validators/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
// If the variable has an index. We expect param to be a collection.
if v.Index != nil {
if cType := param.GetType().GetCollectionType(); cType == nil {
errs.Collect(errors.NewMismatchingVariablesErr(nodeID, outputVar, param.Type.String(), inputVar, expectedType.String()))
errs.Collect(errors.NewMismatchingVariablesErr(nodeID, outputVar, c.LiteralTypeToStr(param.Type), inputVar, c.LiteralTypeToStr(expectedType)))

Check warning on line 134 in flytepropeller/pkg/compiler/validators/bindings.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/bindings.go#L134

Added line #L134 was not covered by tests
} else {
sourceType = cType
}
Expand Down Expand Up @@ -164,7 +164,7 @@
return param.GetType(), []c.NodeID{val.Promise.NodeId}, true
}

errs.Collect(errors.NewMismatchingVariablesErr(node.GetId(), outputVar, sourceType.String(), inputVar, expectedType.String()))
errs.Collect(errors.NewMismatchingVariablesErr(node.GetId(), outputVar, c.LiteralTypeToStr(sourceType), inputVar, c.LiteralTypeToStr(expectedType)))
return nil, nil, !errs.HasErrors()
}
}
Expand All @@ -180,7 +180,7 @@
if literalType == nil {
errs.Collect(errors.NewUnrecognizedValueErr(nodeID, reflect.TypeOf(val.Scalar.GetValue()).String()))
} else if validateParamTypes && !AreTypesCastable(literalType, expectedType) {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, nodeParam, literalType.String(), expectedType.String()))
errs.Collect(errors.NewMismatchingTypesErr(nodeID, nodeParam, c.LiteralTypeToStr(literalType), c.LiteralTypeToStr(expectedType)))
}

if expectedType.GetEnumType() != nil {
Expand Down
2 changes: 1 addition & 1 deletion flytepropeller/pkg/compiler/validators/condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
if op1Valid && op2Valid && op1Type != nil && op2Type != nil {
if op1Type.String() != op2Type.String() {
errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue",
op1Type.String(), op2Type.String()))
c.LiteralTypeToStr(op1Type), c.LiteralTypeToStr(op2Type)))

Check warning on line 47 in flytepropeller/pkg/compiler/validators/condition.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/condition.go#L47

Added line #L47 was not covered by tests
}
}
} else if expr.GetConjunction() != nil {
Expand Down
2 changes: 1 addition & 1 deletion flytepropeller/pkg/compiler/validators/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
func validateVarType(nodeID c.NodeID, paramName string, param *flyte.Variable,
expectedType *flyte.LiteralType, errs errors.CompileErrors) (ok bool) {
if param.GetType().String() != expectedType.String() {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, paramName, param.GetType().String(), expectedType.String()))
errs.Collect(errors.NewMismatchingTypesErr(nodeID, paramName, c.LiteralTypeToStr(param.GetType()), c.LiteralTypeToStr(expectedType)))

Check warning on line 43 in flytepropeller/pkg/compiler/validators/vars.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/vars.go#L43

Added line #L43 was not covered by tests
}

return !errs.HasErrors()
Expand Down
Loading