Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Support union and none type in flyteidl (#401)
Browse files Browse the repository at this point in the history
* add support for Union Scalar

Signed-off-by: Yubo Wang <[email protected]>

* support union type and literals

Signed-off-by: Yubo Wang <[email protected]>

* change union type extraction

Signed-off-by: Yubo Wang <[email protected]>

---------

Signed-off-by: Yubo Wang <[email protected]>
Co-authored-by: Yubo Wang <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
3 people authored May 9, 2023
1 parent 44ae401 commit 9247ff0
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 3 deletions.
9 changes: 9 additions & 0 deletions clients/go/coreutils/extract_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ func ExtractFromLiteral(literal *core.Literal) (interface{}, error) {
return scalarValue.Generic, nil
case *core.Scalar_StructuredDataset:
return scalarValue.StructuredDataset.Uri, nil
case *core.Scalar_Union:
// extract the value of the union but not the actual union object
extractedVal, err := ExtractFromLiteral(scalarValue.Union.Value)
if err != nil {
return nil, err
}
return extractedVal, nil
case *core.Scalar_NoneType:
return nil, nil
default:
return nil, fmt.Errorf("unsupported literal scalar type %T", scalarValue)
}
Expand Down
40 changes: 39 additions & 1 deletion clients/go/coreutils/extract_literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestFetchLiteral(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, p.GetScalar())
_, err = ExtractFromLiteral(p)
assert.NotNil(t, err)
assert.Nil(t, err)
})

t.Run("Generic", func(t *testing.T) {
Expand Down Expand Up @@ -199,4 +199,42 @@ func TestFetchLiteral(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union", func(t *testing.T) {
literalVal := int64(1)
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
}
lit, err := MakeLiteralForType(literalType, literalVal)
assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Equal(t, literalVal, extractedLiteralVal)
})

t.Run("Union with None", func(t *testing.T) {
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}},
},
},
},
}
lit, err := MakeLiteralForType(literalType, nil)

assert.NoError(t, err)
extractedLiteralVal, err := ExtractFromLiteral(lit)
assert.NoError(t, err)
assert.Nil(t, extractedLiteralVal)
})
}
48 changes: 48 additions & 0 deletions clients/go/coreutils/literals.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,28 @@ func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) {
return MakeLiteralForType(typ, nil)
case *core.LiteralType_Schema:
return MakeLiteralForType(typ, nil)
case *core.LiteralType_UnionType:
if len(t.UnionType.Variants) == 0 {
return nil, errors.Errorf("Union type must have at least one variant")
}
// For union types, we just return the default for the first variant
val, err := MakeDefaultLiteralForType(t.UnionType.Variants[0])
if err != nil {
return nil, errors.Errorf("Failed to create default literal for first union type variant [%v]", t.UnionType.Variants[0])
}
res := &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Type: t.UnionType.Variants[0],
Value: val,
},
},
},
},
}
return res, nil
}

return nil, fmt.Errorf("failed to convert to a known Literal. Input Type [%v] not supported", typ.String())
Expand Down Expand Up @@ -588,6 +610,32 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro
}
return MakePrimitiveLiteral(newV)

case *core.LiteralType_UnionType:
// Try different types in the variants, return the first one matched
found := false
for _, subType := range newT.UnionType.Variants {
lv, err := MakeLiteralForType(subType, v)
if err == nil {
l = &core.Literal{
Value: &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Value: lv,
Type: subType,
},
},
},
},
}
found = true
break
}
}
if !found {
return nil, fmt.Errorf("incorrect union value [%s], supported values %+v", v, newT.UnionType.Variants)
}

default:
return nil, fmt.Errorf("unsupported type %s", t.String())
}
Expand Down
47 changes: 47 additions & 0 deletions clients/go/coreutils/literals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,23 @@ func TestMakeDefaultLiteralForType(t *testing.T) {
Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "x"}}}}}}
assert.Equal(t, expected, l)
})

t.Run("union", func(t *testing.T) {
l, err := MakeDefaultLiteralForType(
&core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
},
)
assert.NoError(t, err)
assert.Equal(t, "*core.Union", reflect.TypeOf(l.GetScalar().GetUnion()).String())
})
}

func TestMustMakeDefaultLiteralForType(t *testing.T) {
Expand Down Expand Up @@ -715,4 +732,34 @@ func TestMakeLiteralForType(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, expectedVal, actualVal)
})

t.Run("Union", func(t *testing.T) {
var literalType = &core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}},
{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
},
},
},
}
expectedLV := &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}},
Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_FloatValue{FloatValue: 0.1}}}}}},
},
},
}}}
lv, err := MakeLiteralForType(literalType, float64(0.1))
assert.NoError(t, err)
assert.Equal(t, expectedLV, lv)
expectedVal, err := ExtractFromLiteral(expectedLV)
assert.NoError(t, err)
actualVal, err := ExtractFromLiteral(lv)
assert.NoError(t, err)
assert.Equal(t, expectedVal, actualVal)
})
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ require (
k8s.io/klog/v2 v2.5.0 // indirect
)

// These 2 versions were wrongly published.
// These 2 versions were wrongly published.
retract (
v1.4.0
v1.4.2
v1.4.0
)

0 comments on commit 9247ff0

Please sign in to comment.