From bc5350fb949af7f6c49640502ad0e93a0171b834 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss <77798312+pmahindrakar-oss@users.noreply.github.com> Date: Fri, 28 May 2021 11:22:04 +0530 Subject: [PATCH] Removed literals.go from flyteplugins and reusing it from flyteidl (#177) --- go/tasks/pluginmachinery/utils/literals.go | 439 ------------------ .../pluginmachinery/utils/literals_test.go | 358 -------------- .../k8s/sagemaker/plugin_test_utils.go | 17 +- 3 files changed, 9 insertions(+), 805 deletions(-) delete mode 100644 go/tasks/pluginmachinery/utils/literals.go delete mode 100644 go/tasks/pluginmachinery/utils/literals_test.go diff --git a/go/tasks/pluginmachinery/utils/literals.go b/go/tasks/pluginmachinery/utils/literals.go deleted file mode 100644 index 6302d4a45e..0000000000 --- a/go/tasks/pluginmachinery/utils/literals.go +++ /dev/null @@ -1,439 +0,0 @@ -package utils - -import ( - "fmt" - "reflect" - "strconv" - "strings" - "time" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/storage" - "github.com/golang/protobuf/jsonpb" - "github.com/golang/protobuf/ptypes" - structpb "github.com/golang/protobuf/ptypes/struct" - "github.com/pkg/errors" -) - -func MakePrimitive(v interface{}) (*core.Primitive, error) { - switch p := v.(type) { - case int: - return &core.Primitive{ - Value: &core.Primitive_Integer{ - Integer: int64(p), - }, - }, nil - case int64: - return &core.Primitive{ - Value: &core.Primitive_Integer{ - Integer: p, - }, - }, nil - case float64: - return &core.Primitive{ - Value: &core.Primitive_FloatValue{ - FloatValue: p, - }, - }, nil - case time.Time: - t, err := ptypes.TimestampProto(p) - if err != nil { - return nil, err - } - return &core.Primitive{ - Value: &core.Primitive_Datetime{ - Datetime: t, - }, - }, nil - case time.Duration: - d := ptypes.DurationProto(p) - return &core.Primitive{ - Value: &core.Primitive_Duration{ - Duration: d, - }, - }, nil - case string: - return &core.Primitive{ - Value: &core.Primitive_StringValue{ - StringValue: p, - }, - }, nil - case bool: - return &core.Primitive{ - Value: &core.Primitive_Boolean{ - Boolean: p, - }, - }, nil - } - return nil, errors.Errorf("Failed to convert to a known primitive type. Input Type [%v] not supported", reflect.TypeOf(v).String()) -} - -func MustMakePrimitive(v interface{}) *core.Primitive { - f, err := MakePrimitive(v) - if err != nil { - panic(err) - } - return f -} - -func MakePrimitiveLiteral(v interface{}) (*core.Literal, error) { - p, err := MakePrimitive(v) - if err != nil { - return nil, err - } - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Primitive{ - Primitive: p, - }, - }, - }, - }, nil -} - -func MustMakePrimitiveLiteral(v interface{}) *core.Literal { - p, err := MakePrimitiveLiteral(v) - if err != nil { - panic(err) - } - return p -} - -func MakeLiteralMap(v map[string]interface{}) (*core.LiteralMap, error) { - - literals := make(map[string]*core.Literal, len(v)) - for key, val := range v { - l, err := MakeLiteral(val) - if err != nil { - return nil, err - } - - literals[key] = l - } - - return &core.LiteralMap{ - Literals: literals, - }, nil -} - -func MakeLiteralForMap(v map[string]interface{}) (*core.Literal, error) { - m, err := MakeLiteralMap(v) - if err != nil { - return nil, err - } - - return &core.Literal{ - Value: &core.Literal_Map{ - Map: m, - }, - }, nil -} - -func MakeLiteralForCollection(v []interface{}) (*core.Literal, error) { - literals := make([]*core.Literal, 0, len(v)) - for _, val := range v { - l, err := MakeLiteral(val) - if err != nil { - return nil, err - } - - literals = append(literals, l) - } - - return &core.Literal{ - Value: &core.Literal_Collection{ - Collection: &core.LiteralCollection{ - Literals: literals, - }, - }, - }, nil -} - -func MakeBinaryLiteral(v []byte) *core.Literal { - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: v, - }, - }, - }, - }, - } -} - -func MakeGenericLiteral(v *structpb.Struct) *core.Literal { - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Generic{ - Generic: v, - }, - }, - }} -} - -func MakeLiteral(v interface{}) (*core.Literal, error) { - if v == nil { - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_NoneType{ - NoneType: &core.Void{}, - }, - }, - }, - }, nil - } - switch o := v.(type) { - case *core.Literal: - return o, nil - case []interface{}: - return MakeLiteralForCollection(o) - case map[string]interface{}: - return MakeLiteralForMap(o) - case []byte: - return MakeBinaryLiteral(v.([]byte)), nil - case *structpb.Struct: - return MakeGenericLiteral(v.(*structpb.Struct)), nil - case *core.Error: - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Error{ - Error: v.(*core.Error), - }, - }, - }, - }, nil - default: - return MakePrimitiveLiteral(o) - } -} - -func MustMakeLiteral(v interface{}) *core.Literal { - p, err := MakeLiteral(v) - if err != nil { - panic(err) - } - - return p -} - -func MustMakeDefaultLiteralForType(typ *core.LiteralType) *core.Literal { - if res, err := MakeDefaultLiteralForType(typ); err != nil { - panic(err) - } else { - return res - } -} - -func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) { - if typ != nil { - switch t := typ.GetType().(type) { - case *core.LiteralType_Simple: - switch t.Simple { - case core.SimpleType_NONE: - return MakeLiteral(nil) - case core.SimpleType_INTEGER: - return MakeLiteral(int(0)) - case core.SimpleType_FLOAT: - return MakeLiteral(float64(0)) - case core.SimpleType_STRING: - return MakeLiteral("") - case core.SimpleType_BOOLEAN: - return MakeLiteral(false) - case core.SimpleType_DATETIME: - return MakeLiteral(time.Now()) - case core.SimpleType_DURATION: - return MakeLiteral(time.Second) - case core.SimpleType_BINARY: - return MakeLiteral([]byte{}) - case core.SimpleType_ERROR: - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Error{ - Error: &core.Error{ - Message: "Default Error message", - }, - }, - }, - }, - }, nil - case core.SimpleType_STRUCT: - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Generic{ - Generic: &structpb.Struct{}, - }, - }, - }, - }, nil - } - return nil, errors.Errorf("Not yet implemented. Default creation is not yet implemented for [%s] ", t.Simple.String()) - - case *core.LiteralType_Blob: - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Blob{ - Blob: &core.Blob{ - Metadata: &core.BlobMetadata{ - Type: t.Blob, - }, - Uri: "/tmp/somepath", - }, - }, - }, - }, - }, nil - case *core.LiteralType_CollectionType: - single, err := MakeDefaultLiteralForType(t.CollectionType) - if err != nil { - return nil, err - } - - return &core.Literal{ - Value: &core.Literal_Collection{ - Collection: &core.LiteralCollection{ - Literals: []*core.Literal{single}, - }, - }, - }, nil - case *core.LiteralType_MapValueType: - single, err := MakeDefaultLiteralForType(t.MapValueType) - if err != nil { - return nil, err - } - - return &core.Literal{ - Value: &core.Literal_Map{ - Map: &core.LiteralMap{ - Literals: map[string]*core.Literal{ - "itemKey": single, - }, - }, - }, - }, nil - // case *core.LiteralType_Schema: - } - } - return nil, errors.Errorf("Failed to convert to a known Literal. Input Type [%v] not supported", typ.String()) -} - -func MakePrimitiveForType(t core.SimpleType, s string) (*core.Primitive, error) { - p := &core.Primitive{} - switch t { - case core.SimpleType_INTEGER: - v, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return nil, errors.Wrap(err, "failed to parse integer value") - } - p.Value = &core.Primitive_Integer{Integer: v} - case core.SimpleType_FLOAT: - v, err := strconv.ParseFloat(s, 64) - if err != nil { - return nil, errors.Wrap(err, "failed to parse Float value") - } - p.Value = &core.Primitive_FloatValue{FloatValue: v} - case core.SimpleType_BOOLEAN: - v, err := strconv.ParseBool(s) - if err != nil { - return nil, errors.Wrap(err, "failed to parse Bool value") - } - p.Value = &core.Primitive_Boolean{Boolean: v} - case core.SimpleType_STRING: - p.Value = &core.Primitive_StringValue{StringValue: s} - case core.SimpleType_DURATION: - v, err := time.ParseDuration(s) - if err != nil { - return nil, errors.Wrap(err, "failed to parse Duration, valid formats: e.g. 300ms, -1.5h, 2h45m") - } - p.Value = &core.Primitive_Duration{Duration: ptypes.DurationProto(v)} - case core.SimpleType_DATETIME: - v, err := time.Parse(time.RFC3339, s) - if err != nil { - return nil, errors.Wrap(err, "failed to parse Datetime in RFC3339 format") - } - ts, err := ptypes.TimestampProto(v) - if err != nil { - return nil, errors.Wrap(err, "failed to convert datetime to proto") - } - p.Value = &core.Primitive_Datetime{Datetime: ts} - default: - return nil, fmt.Errorf("unsupported type %s", t.String()) - } - return p, nil -} - -func MakeLiteralForSimpleType(t core.SimpleType, s string) (*core.Literal, error) { - s = strings.Trim(s, " \n\t") - scalar := &core.Scalar{} - switch t { - case core.SimpleType_STRUCT: - st := &structpb.Struct{} - err := jsonpb.UnmarshalString(s, st) - if err != nil { - return nil, errors.Wrapf(err, "failed to load generic type as json.") - } - scalar.Value = &core.Scalar_Generic{ - Generic: st, - } - case core.SimpleType_BINARY: - scalar.Value = &core.Scalar_Binary{ - Binary: &core.Binary{ - Value: []byte(s), - // TODO Tag not supported at the moment - }, - } - case core.SimpleType_ERROR: - scalar.Value = &core.Scalar_Error{ - Error: &core.Error{ - Message: s, - }, - } - case core.SimpleType_NONE: - scalar.Value = &core.Scalar_NoneType{ - NoneType: &core.Void{}, - } - default: - p, err := MakePrimitiveForType(t, s) - if err != nil { - return nil, err - } - scalar.Value = &core.Scalar_Primitive{Primitive: p} - } - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: scalar, - }, - }, nil -} - -func MakeLiteralForBlob(path storage.DataReference, isDir bool, format string) *core.Literal { - dim := core.BlobType_SINGLE - if isDir { - dim = core.BlobType_MULTIPART - } - return &core.Literal{ - Value: &core.Literal_Scalar{ - Scalar: &core.Scalar{ - Value: &core.Scalar_Blob{ - Blob: &core.Blob{ - Uri: path.String(), - Metadata: &core.BlobMetadata{ - Type: &core.BlobType{ - Dimensionality: dim, - Format: format, - }, - }, - }, - }, - }, - }, - } -} diff --git a/go/tasks/pluginmachinery/utils/literals_test.go b/go/tasks/pluginmachinery/utils/literals_test.go deleted file mode 100644 index aeabe53e08..0000000000 --- a/go/tasks/pluginmachinery/utils/literals_test.go +++ /dev/null @@ -1,358 +0,0 @@ -package utils - -import ( - "reflect" - "testing" - "time" - - "github.com/go-test/deep" - - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flytestdlib/storage" - "github.com/golang/protobuf/ptypes" - structpb "github.com/golang/protobuf/ptypes/struct" - "github.com/stretchr/testify/assert" -) - -func TestMakePrimitive(t *testing.T) { - { - v := 1 - p, err := MakePrimitive(v) - assert.NoError(t, err) - assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.Value).String()) - assert.Equal(t, int64(v), p.GetInteger()) - } - { - v := int64(1) - p, err := MakePrimitive(v) - assert.NoError(t, err) - assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(p.Value).String()) - assert.Equal(t, v, p.GetInteger()) - } - { - v := 1.0 - p, err := MakePrimitive(v) - assert.NoError(t, err) - assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.Value).String()) - assert.Equal(t, v, p.GetFloatValue()) - } - { - v := "blah" - p, err := MakePrimitive(v) - assert.NoError(t, err) - assert.Equal(t, "*core.Primitive_StringValue", reflect.TypeOf(p.Value).String()) - assert.Equal(t, v, p.GetStringValue()) - } - { - v := true - p, err := MakePrimitive(v) - assert.NoError(t, err) - assert.Equal(t, "*core.Primitive_Boolean", reflect.TypeOf(p.Value).String()) - assert.Equal(t, v, p.GetBoolean()) - } - { - v := time.Now() - p, err := MakePrimitive(v) - assert.NoError(t, err) - assert.Equal(t, "*core.Primitive_Datetime", reflect.TypeOf(p.Value).String()) - j, err := ptypes.TimestampProto(v) - assert.NoError(t, err) - assert.Equal(t, j, p.GetDatetime()) - _, err = MakePrimitive(time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC)) - assert.Error(t, err) - } - { - v := time.Second * 10 - p, err := MakePrimitive(v) - assert.NoError(t, err) - assert.Equal(t, "*core.Primitive_Duration", reflect.TypeOf(p.Value).String()) - assert.Equal(t, ptypes.DurationProto(v), p.GetDuration()) - } - { - v := struct { - }{} - _, err := MakePrimitive(v) - assert.Error(t, err) - } -} - -func TestMustMakePrimitive(t *testing.T) { - { - v := struct { - }{} - assert.Panics(t, func() { - MustMakePrimitive(v) - }) - } - { - v := time.Second * 10 - p := MustMakePrimitive(v) - assert.Equal(t, "*core.Primitive_Duration", reflect.TypeOf(p.Value).String()) - assert.Equal(t, ptypes.DurationProto(v), p.GetDuration()) - } -} - -func TestMakePrimitiveLiteral(t *testing.T) { - { - v := 1.0 - p, err := MakePrimitiveLiteral(v) - assert.NoError(t, err) - assert.NotNil(t, p.GetScalar()) - assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) - assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) - } - { - v := struct { - }{} - _, err := MakePrimitiveLiteral(v) - assert.Error(t, err) - } -} - -func TestMustMakePrimitiveLiteral(t *testing.T) { - t.Run("Panic", func(t *testing.T) { - v := struct { - }{} - assert.Panics(t, func() { - MustMakePrimitiveLiteral(v) - }) - }) - t.Run("FloatValue", func(t *testing.T) { - v := 1.0 - p := MustMakePrimitiveLiteral(v) - assert.NotNil(t, p.GetScalar()) - assert.Equal(t, "*core.Primitive_FloatValue", reflect.TypeOf(p.GetScalar().GetPrimitive().Value).String()) - assert.Equal(t, v, p.GetScalar().GetPrimitive().GetFloatValue()) - }) -} - -func TestMakeLiteral(t *testing.T) { - t.Run("Primitive", func(t *testing.T) { - lit, err := MakeLiteral("test_string") - assert.NoError(t, err) - assert.Equal(t, "*core.Primitive_StringValue", reflect.TypeOf(lit.GetScalar().GetPrimitive().Value).String()) - }) - - t.Run("Array", func(t *testing.T) { - lit, err := MakeLiteral([]interface{}{1, 2, 3}) - assert.NoError(t, err) - assert.Equal(t, "*core.Literal_Collection", reflect.TypeOf(lit.GetValue()).String()) - assert.Equal(t, "*core.Primitive_Integer", reflect.TypeOf(lit.GetCollection().Literals[0].GetScalar().GetPrimitive().Value).String()) - }) - - t.Run("Map", func(t *testing.T) { - lit, err := MakeLiteral(map[string]interface{}{ - "key1": []interface{}{1, 2, 3}, - "key2": []interface{}{5}, - }) - assert.NoError(t, err) - assert.Equal(t, "*core.Literal_Map", reflect.TypeOf(lit.GetValue()).String()) - assert.Equal(t, "*core.Literal_Collection", reflect.TypeOf(lit.GetMap().Literals["key1"].GetValue()).String()) - }) - - t.Run("Binary", func(t *testing.T) { - s := MakeBinaryLiteral([]byte{'h'}) - assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) - }) - - t.Run("NoneType", func(t *testing.T) { - p, err := MakeLiteral(nil) - assert.NoError(t, err) - assert.NotNil(t, p.GetScalar()) - assert.Equal(t, "*core.Scalar_NoneType", reflect.TypeOf(p.GetScalar().Value).String()) - }) -} - -func TestMustMakeLiteral(t *testing.T) { - v := "hello" - l := MustMakeLiteral(v) - assert.NotNil(t, l.GetScalar()) - assert.Equal(t, v, l.GetScalar().GetPrimitive().GetStringValue()) -} - -func TestMakeBinaryLiteral(t *testing.T) { - s := MakeBinaryLiteral([]byte{'h'}) - assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) -} - -func TestMakeDefaultLiteralForType(t *testing.T) { - type args struct { - name string - ty core.SimpleType - tyName string - isPrimitive bool - } - tests := []args{ - {"None", core.SimpleType_NONE, "*core.Scalar_NoneType", false}, - {"Binary", core.SimpleType_BINARY, "*core.Scalar_Binary", false}, - {"Integer", core.SimpleType_INTEGER, "*core.Primitive_Integer", true}, - {"Float", core.SimpleType_FLOAT, "*core.Primitive_FloatValue", true}, - {"String", core.SimpleType_STRING, "*core.Primitive_StringValue", true}, - {"Boolean", core.SimpleType_BOOLEAN, "*core.Primitive_Boolean", true}, - {"Duration", core.SimpleType_DURATION, "*core.Primitive_Duration", true}, - {"Datetime", core.SimpleType_DATETIME, "*core.Primitive_Datetime", true}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{Simple: test.ty}}) - assert.NoError(t, err) - if test.isPrimitive { - assert.Equal(t, test.tyName, reflect.TypeOf(l.GetScalar().GetPrimitive().Value).String()) - } else { - assert.Equal(t, test.tyName, reflect.TypeOf(l.GetScalar().Value).String()) - } - }) - } - - t.Run("Blob", func(t *testing.T) { - l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Blob{}}) - assert.NoError(t, err) - assert.Equal(t, "*core.Scalar_Blob", reflect.TypeOf(l.GetScalar().Value).String()) - }) - - t.Run("Collection", func(t *testing.T) { - l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_CollectionType{CollectionType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}}) - assert.NoError(t, err) - assert.Equal(t, "*core.LiteralCollection", reflect.TypeOf(l.GetCollection()).String()) - }) - - t.Run("Map", func(t *testing.T) { - l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_MapValueType{MapValueType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}}) - assert.NoError(t, err) - assert.Equal(t, "*core.LiteralMap", reflect.TypeOf(l.GetMap()).String()) - }) - - t.Run("error", func(t *testing.T) { - l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_ERROR, - }}) - assert.NoError(t, err) - assert.NotNil(t, l.GetScalar().GetError()) - }) - - t.Run("struct", func(t *testing.T) { - l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_STRUCT, - }}) - assert.NoError(t, err) - assert.NotNil(t, l.GetScalar().GetGeneric()) - }) -} - -func TestMustMakeDefaultLiteralForType(t *testing.T) { - t.Run("error", func(t *testing.T) { - assert.Panics(t, func() { - MustMakeDefaultLiteralForType(nil) - }) - }) - - t.Run("Blob", func(t *testing.T) { - l := MustMakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Blob{}}) - assert.Equal(t, "*core.Scalar_Blob", reflect.TypeOf(l.GetScalar().Value).String()) - }) -} - -func TestMakePrimitiveForType(t *testing.T) { - n := time.Now() - type args struct { - t core.SimpleType - s string - } - tests := []struct { - name string - args args - want *core.Primitive - wantErr bool - }{ - {"error-type", args{core.SimpleType_NONE, "x"}, nil, true}, - - {"error-int", args{core.SimpleType_INTEGER, "x"}, nil, true}, - {"int", args{core.SimpleType_INTEGER, "1"}, MustMakePrimitive(1), false}, - - {"error-bool", args{core.SimpleType_BOOLEAN, "x"}, nil, true}, - {"bool", args{core.SimpleType_BOOLEAN, "true"}, MustMakePrimitive(true), false}, - - {"error-float", args{core.SimpleType_FLOAT, "x"}, nil, true}, - {"float", args{core.SimpleType_FLOAT, "3.1416"}, MustMakePrimitive(3.1416), false}, - - {"string", args{core.SimpleType_STRING, "string"}, MustMakePrimitive("string"), false}, - - {"error-dt", args{core.SimpleType_DATETIME, "x"}, nil, true}, - {"dt", args{core.SimpleType_DATETIME, n.Format(time.RFC3339Nano)}, MustMakePrimitive(n), false}, - - {"error-dur", args{core.SimpleType_DURATION, "x"}, nil, true}, - {"dur", args{core.SimpleType_DURATION, time.Hour.String()}, MustMakePrimitive(time.Hour), false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := MakePrimitiveForType(tt.args.t, tt.args.s) - if (err != nil) != tt.wantErr { - t.Errorf("MakePrimitiveForType() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("MakePrimitiveForType() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestMakeLiteralForSimpleType(t *testing.T) { - type args struct { - t core.SimpleType - s string - } - tests := []struct { - name string - args args - want *core.Literal - wantErr bool - }{ - {"error-int", args{core.SimpleType_INTEGER, "x"}, nil, true}, - {"int", args{core.SimpleType_INTEGER, "1"}, MustMakeLiteral(1), false}, - - {"error-struct", args{core.SimpleType_STRUCT, "x"}, nil, true}, - {"struct", args{core.SimpleType_STRUCT, `{"x": 1}`}, MustMakeLiteral(&structpb.Struct{Fields: map[string]*structpb.Value{"x": {Kind: &structpb.Value_NumberValue{NumberValue: 1}}}}), false}, - - {"bin", args{core.SimpleType_BINARY, "x"}, MustMakeLiteral([]byte("x")), false}, - - {"error", args{core.SimpleType_ERROR, "err"}, MustMakeLiteral(&core.Error{Message: "err"}), false}, - - {"none", args{core.SimpleType_NONE, "null"}, MustMakeLiteral(nil), false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := MakeLiteralForSimpleType(tt.args.t, tt.args.s) - if (err != nil) != tt.wantErr { - t.Errorf("MakeLiteralForSimpleType() error = %v, wantErr %v", err, tt.wantErr) - return - } - if diff := deep.Equal(tt.want, got); diff != nil { - t.Errorf("MakeLiteralForSimpleType() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestMakeLiteralForBlob(t *testing.T) { - type args struct { - path storage.DataReference - isDir bool - format string - } - tests := []struct { - name string - args args - want *core.Blob - }{ - {"simple-key", args{path: "/key", isDir: false, format: "xyz"}, &core.Blob{Uri: "/key", Metadata: &core.BlobMetadata{Type: &core.BlobType{Format: "xyz", Dimensionality: core.BlobType_SINGLE}}}}, - {"simple-dir", args{path: "/key", isDir: true, format: "xyz"}, &core.Blob{Uri: "/key", Metadata: &core.BlobMetadata{Type: &core.BlobType{Format: "xyz", Dimensionality: core.BlobType_MULTIPART}}}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := MakeLiteralForBlob(tt.args.path, tt.args.isDir, tt.args.format); !reflect.DeepEqual(got.GetScalar().GetBlob(), tt.want) { - t.Errorf("MakeLiteralForBlob() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go b/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go index 206eac0ead..9814c953f3 100644 --- a/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go +++ b/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go @@ -1,6 +1,7 @@ package sagemaker import ( + "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flytestdlib/promutils" "github.com/pkg/errors" @@ -153,10 +154,10 @@ func generateMockCustomTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTem Literals: map[string]*flyteIdlCore.Literal{ "train": generateMockBlobLiteral(trainBlobLoc), "validation": generateMockBlobLiteral(validationBlobLoc), - "hp_int": utils.MustMakeLiteral(1), - "hp_float": utils.MustMakeLiteral(1.5), - "hp_bool": utils.MustMakeLiteral(false), - "hp_string": utils.MustMakeLiteral("a"), + "hp_int": coreutils.MustMakeLiteral(1), + "hp_float": coreutils.MustMakeLiteral(1.5), + "hp_bool": coreutils.MustMakeLiteral(false), + "hp_string": coreutils.MustMakeLiteral("a"), }, }, nil) taskCtx.OnInputReader().Return(inputReader) @@ -237,7 +238,7 @@ func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate, Literals: map[string]*flyteIdlCore.Literal{ "train": generateMockBlobLiteral(trainBlobLoc), "validation": generateMockBlobLiteral(validationBlobLoc), - "static_hyperparameters": utils.MakeGenericLiteral(shpStructObj), + "static_hyperparameters": coreutils.MakeGenericLiteral(shpStructObj), }, }, nil) taskCtx.OnInputReader().Return(inputReader) @@ -347,9 +348,9 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T Literals: map[string]*flyteIdlCore.Literal{ "train": generateMockBlobLiteral(trainBlobLoc), "validation": generateMockBlobLiteral(validationBlobLoc), - "static_hyperparameters": utils.MakeGenericLiteral(shpStructObj), - "hyperparameter_tuning_job_config": utils.MakeBinaryLiteral(hpoJobConfigByteArray), - "a": utils.MakeGenericLiteral(intParamRange), + "static_hyperparameters": coreutils.MakeGenericLiteral(shpStructObj), + "hyperparameter_tuning_job_config": coreutils.MakeBinaryLiteral(hpoJobConfigByteArray), + "a": coreutils.MakeGenericLiteral(intParamRange), }, }, nil) taskCtx.OnInputReader().Return(inputReader)