Skip to content

Commit

Permalink
Support binding pflags to default variable fields (flyteorg#69)
Browse files Browse the repository at this point in the history
* Update CI post migration (flyteorg#68)

* Update CI post migration

* Migrate to github workflows

* length check

* debug

* Avoid installing pflags from repo

* Update deps

* update go action

* try import path

* update protos

* set checkout depth

* cleanup

* typo in master wf

Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Disable scoope publishing

Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Rename (flyteorg#73)

* first pass at updating references

* update module name

* update boilerplate mod

* remove version

* mor

* mockery

* maybe

* mor

* replace replace

* trying something

* other

Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Remove dependency on lyft/api (flyteorg#71)

Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Always try to create a bucket when loading a container (flyteorg#76)

Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Update create container if not exists logic (flyteorg#77)

Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Support binding pflags to default variable fields
Support Map Types now that viper does

Signed-off-by: Haytham Abuelfutuh <[email protected]>
Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Added some coverage

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Reverting to go 1.16

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Fixed unit test

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Added more coverage

Signed-off-by: Prafulla Mahindrakar <[email protected]>

Co-authored-by: brucearctor <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Katrina Rogan <[email protected]>
Co-authored-by: Prafulla Mahindrakar <[email protected]>
  • Loading branch information
5 people authored May 18, 2021
1 parent 99f4b20 commit fd1a616
Show file tree
Hide file tree
Showing 11 changed files with 297 additions and 242 deletions.
166 changes: 118 additions & 48 deletions cli/pflags/api/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ const (

// PFlagProviderGenerator parses and generates GetPFlagSet implementation to add PFlags for a given struct's fields.
type PFlagProviderGenerator struct {
pkg *types.Package
st *types.Named
defaultVar *types.Var
pkg *types.Package
st *types.Named
defaultVar *types.Var
shouldBindDefaultVar bool
}

// This list is restricted because that's the only kinds viper parses out, otherwise it assumes strings.
Expand All @@ -35,6 +36,7 @@ var allowedKinds = []types.Type{
types.Typ[types.Int64],
types.Typ[types.Bool],
types.Typ[types.String],
types.NewMap(types.Typ[types.String], types.Typ[types.String]),
}

type SliceOrArray interface {
Expand All @@ -49,8 +51,8 @@ func capitalize(s string) string {
return s
}

func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage, defaultValue string) (FieldInfo, error) {
strategy := SliceRaw
func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage, defaultValue string, bindDefaultVar bool) (FieldInfo, error) {
strategy := Raw
FlagMethodName := "StringSlice"
typ := types.NewSlice(types.Typ[types.String])
emptyDefaultValue := `[]string{}`
Expand All @@ -76,14 +78,59 @@ func buildFieldForSlice(ctx context.Context, t SliceOrArray, name, goName, usage
}

return FieldInfo{
Name: name,
GoName: goName,
Typ: typ,
FlagMethodName: FlagMethodName,
DefaultValue: defaultValue,
UsageString: usage,
TestValue: testValue,
TestStrategy: strategy,
Name: name,
GoName: goName,
Typ: typ,
FlagMethodName: FlagMethodName,
DefaultValue: defaultValue,
UsageString: usage,
TestValue: testValue,
TestStrategy: strategy,
ShouldBindDefault: bindDefaultVar,
}, nil
}

func buildFieldForMap(ctx context.Context, t *types.Map, name, goName, usage, defaultValue string, bindDefaultVar bool) (FieldInfo, error) {
strategy := Raw
FlagMethodName := "StringToString"
typ := types.NewMap(types.Typ[types.String], types.Typ[types.String])
emptyDefaultValue := `nil`
if k, ok := t.Key().(*types.Basic); !ok || k.Kind() != types.String {
logger.Infof(ctx, "Key of type [%v] is not a basic type. It must be json unmarshalable or generation will fail.", t.Elem())
} else if v, valueOk := t.Elem().(*types.Basic); !valueOk && !isJSONUnmarshaler(t.Elem()) {
return FieldInfo{},
fmt.Errorf("map of type [%v] is not supported. Only basic slices or slices of json-unmarshalable types are supported",
t.Elem().String())
} else {
logger.Infof(ctx, "Map[%v]%v is supported. using pflag maps.", k, t.Elem())
strategy = Raw
if valueOk {
FlagMethodName = fmt.Sprintf("StringTo%v", capitalize(v.Name()))
typ = types.NewMap(k, v)
emptyDefaultValue = fmt.Sprintf(`map[%v]%v{}`, k.Name(), v.Name())
} else {
// Value is not a basic type. Rely on json marshaling to unmarshal it
FlagMethodName = fmt.Sprintf("StringToString")
}
}

if len(defaultValue) == 0 {
defaultValue = emptyDefaultValue
}

testValue := `"a=1,b=2"`

return FieldInfo{
Name: name,
GoName: goName,
Typ: typ,
FlagMethodName: FlagMethodName,
DefaultValue: defaultValue,
UsageString: usage,
TestValue: testValue,
TestStrategy: strategy,
ShouldBindDefault: bindDefaultVar,
ShouldTestDefault: false,
}, nil
}

Expand Down Expand Up @@ -121,7 +168,7 @@ func appendAccessors(accessors ...string) string {
// met; encountered a basic type (e.g. string, int... etc.) or the field type implements UnmarshalJSON.
// If passed a non-empty defaultValueAccessor, it'll be used to fill in default values instead of any default value
// specified in pflag tag.
func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor, fieldPath string) ([]FieldInfo, error) {
func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValueAccessor, fieldPath string, bindDefaultVar bool) ([]FieldInfo, error) {
logger.Printf(ctx, "Finding all fields in [%v.%v.%v]",
typ.Obj().Pkg().Path(), typ.Obj().Pkg().Name(), typ.Obj().Name())

Expand Down Expand Up @@ -187,14 +234,15 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue
}

fields = append(fields, FieldInfo{
Name: tag.Name,
GoName: v.Name(),
Typ: t,
FlagMethodName: camelCase(t.String()),
DefaultValue: defaultValue,
UsageString: tag.Usage,
TestValue: `"1"`,
TestStrategy: JSON,
Name: tag.Name,
GoName: v.Name(),
Typ: t,
FlagMethodName: camelCase(t.String()),
DefaultValue: defaultValue,
UsageString: tag.Usage,
TestValue: `"1"`,
TestStrategy: JSON,
ShouldBindDefault: bindDefaultVar,
})
case *types.Named:
if _, isStruct := t.Underlying().(*types.Struct); !isStruct {
Expand All @@ -211,10 +259,14 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue
defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name())
if isStringer(t) {
defaultValue = defaultValue + ".String()"
} else {
} else if isJSONMarshaler(t) {
logger.Infof(ctx, "Field [%v] of type [%v] does not implement Stringer interface."+
" Will use %s.mustMarshalJSON() to get its default value.", defaultValueAccessor, v.Name(), t.String())
defaultValue = fmt.Sprintf("%s.mustMarshalJSON(%s)", defaultValueAccessor, defaultValue)
} else {
logger.Infof(ctx, "Field [%v] of type [%v] does not implement Stringer interface."+
" Will use %s.mustMarshalJSON() to get its default value.", defaultValueAccessor, v.Name(), t.String())
defaultValue = fmt.Sprintf("%s.mustJsonMarshal(%s)", defaultValueAccessor, defaultValue)
}
}

Expand All @@ -229,49 +281,66 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue
logger.Infof(logger.WithIndent(ctx, indent), "Type is json unmarhslalable.")

fields = append(fields, FieldInfo{
Name: tag.Name,
GoName: v.Name(),
Typ: types.Typ[types.String],
FlagMethodName: "String",
DefaultValue: defaultValue,
UsageString: tag.Usage,
TestValue: testValue,
TestStrategy: JSON,
Name: tag.Name,
GoName: v.Name(),
Typ: types.Typ[types.String],
FlagMethodName: "String",
DefaultValue: defaultValue,
UsageString: tag.Usage,
TestValue: testValue,
TestStrategy: JSON,
ShouldBindDefault: bindDefaultVar,
})
} else {
logger.Infof(ctx, "Traversing fields in type.")

nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, defaultValueAccessor, appendAccessors(fieldPath, v.Name()))
nested, err := discoverFieldsRecursive(logger.WithIndent(ctx, indent), t, defaultValueAccessor, appendAccessors(fieldPath, v.Name()), bindDefaultVar)
if err != nil {
return nil, err
}

for _, subField := range nested {
fields = append(fields, FieldInfo{
Name: fmt.Sprintf("%v.%v", tag.Name, subField.Name),
GoName: fmt.Sprintf("%v.%v", v.Name(), subField.GoName),
Typ: subField.Typ,
FlagMethodName: subField.FlagMethodName,
DefaultValue: subField.DefaultValue,
UsageString: subField.UsageString,
TestValue: subField.TestValue,
TestStrategy: subField.TestStrategy,
Name: fmt.Sprintf("%v.%v", tag.Name, subField.Name),
GoName: fmt.Sprintf("%v.%v", v.Name(), subField.GoName),
Typ: subField.Typ,
FlagMethodName: subField.FlagMethodName,
DefaultValue: subField.DefaultValue,
UsageString: subField.UsageString,
TestValue: subField.TestValue,
TestStrategy: subField.TestStrategy,
ShouldBindDefault: bindDefaultVar,
})
}
}
case *types.Slice:
logger.Infof(ctx, "[%v] is of a slice type with default value [%v].", tag.Name, tag.DefaultValue)
defaultValue := tag.DefaultValue

f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue)
f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, defaultValue, bindDefaultVar)
if err != nil {
return nil, err
}

fields = append(fields, f)
case *types.Array:
logger.Infof(ctx, "[%v] is of an array with default value [%v].", tag.Name, tag.DefaultValue)
logger.Infof(ctx, "[%v] is of an array type with default value [%v].", tag.Name, tag.DefaultValue)
defaultValue := tag.DefaultValue

f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, defaultValue, bindDefaultVar)
if err != nil {
return nil, err
}

fields = append(fields, f)
case *types.Map:
logger.Infof(ctx, "[%v] is of a map type with default value [%v].", tag.Name, tag.DefaultValue)
defaultValue := tag.DefaultValue
if len(defaultValueAccessor) > 0 {
defaultValue = appendAccessors(defaultValueAccessor, fieldPath, v.Name())
}

f, err := buildFieldForSlice(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, tag.DefaultValue)
f, err := buildFieldForMap(logger.WithIndent(ctx, indent), t, tag.Name, v.Name(), tag.Usage, defaultValue, bindDefaultVar)
if err != nil {
return nil, err
}
Expand All @@ -288,7 +357,7 @@ func discoverFieldsRecursive(ctx context.Context, typ *types.Named, defaultValue
// NewGenerator initializes a PFlagProviderGenerator for pflags files for targetTypeName struct under pkg. If pkg is not filled in,
// it's assumed to be current package (which is expected to be the common use case when invoking pflags from
// go:generate comments)
func NewGenerator(pkg, targetTypeName, defaultVariableName string) (*PFlagProviderGenerator, error) {
func NewGenerator(pkg, targetTypeName, defaultVariableName string, shouldBindDefaultVar bool) (*PFlagProviderGenerator, error) {
ctx := context.Background()
var err error

Expand Down Expand Up @@ -334,9 +403,10 @@ func NewGenerator(pkg, targetTypeName, defaultVariableName string) (*PFlagProvid
}

return &PFlagProviderGenerator{
st: st,
pkg: targetPackage,
defaultVar: defaultVar,
st: st,
pkg: targetPackage,
defaultVar: defaultVar,
shouldBindDefaultVar: shouldBindDefaultVar,
}, nil
}

Expand Down Expand Up @@ -369,7 +439,7 @@ func (g PFlagProviderGenerator) Generate(ctx context.Context) (PFlagProvider, er
defaultValueAccessor = g.defaultVar.Name()
}

fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor, "")
fields, err := discoverFieldsRecursive(ctx, g.st, defaultValueAccessor, "", g.shouldBindDefaultVar)
if err != nil {
return PFlagProvider{}, err
}
Expand Down
91 changes: 89 additions & 2 deletions cli/pflags/api/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package api
import (
"context"
"flag"
"fmt"
"go/token"
"go/types"
"io/ioutil"
"os"
"path/filepath"
Expand Down Expand Up @@ -45,11 +48,10 @@ func TestElemValueOrNil(t *testing.T) {
}

func TestNewGenerator(t *testing.T) {
g, err := NewGenerator("github.com/flyteorg/flytestdlib/cli/pflags/api", "TestType", "DefaultTestType")
g, err := NewGenerator("github.com/flyteorg/flytestdlib/cli/pflags/api", "TestType", "DefaultTestType", false)
if !assert.NoError(t, err) {
t.FailNow()
}

ctx := context.Background()
p, err := g.Generate(ctx)
if !assert.NoError(t, err) {
Expand Down Expand Up @@ -93,4 +95,89 @@ func TestNewGenerator(t *testing.T) {
goldenTestOutput, err := ioutil.ReadFile(filepath.Clean(goldenTestFilePath))
assert.NoError(t, err)
assert.Equal(t, string(goldenTestOutput), string(testBytes))
t.Run("empty package", func(t *testing.T) {
gen, err := NewGenerator("", "TestType", "DefaultTestType", false)
assert.Nil(t, err)
assert.NotNil(t, gen.GetTargetPackage())
})
}

func TestBuildFieldForMap(t *testing.T) {
t.Run("supported : StringToString", func(t *testing.T) {
ctx := context.Background()
key := types.Typ[types.String]
elem := types.Typ[types.String]
typesMap := types.NewMap(key, elem)
name := "m"
goName := "StringMap"
usage := "I'm a map of strings"
defaultValue := "DefaultValue"
fieldInfo, err := buildFieldForMap(ctx, typesMap, name, goName, usage, defaultValue, false)
assert.Nil(t, err)
assert.NotNil(t, fieldInfo)
assert.Equal(t, "StringToString", fieldInfo.FlagMethodName)
assert.Equal(t, defaultValue, fieldInfo.DefaultValue)
})
t.Run("unsupported : not a string type map", func(t *testing.T) {
ctx := context.Background()
key := types.Typ[types.Bool]
elem := types.Typ[types.Bool]
typesMap := types.NewMap(key, elem)
name := "m"
goName := "BoolMap"
usage := "I'm a map of bools"
defaultValue := ""
fieldInfo, err := buildFieldForMap(ctx, typesMap, name, goName, usage, defaultValue, false)
assert.Nil(t, err)
assert.NotNil(t, fieldInfo)
assert.Equal(t, "StringToString", fieldInfo.FlagMethodName)
assert.Equal(t, "nil", fieldInfo.DefaultValue)
})
t.Run("unsupported : elem not a basic type", func(t *testing.T) {
ctx := context.Background()
key := types.Typ[types.String]
elem := &types.Interface{}
typesMap := types.NewMap(key, elem)
name := "m"
goName := "InterfaceMap"
usage := "I'm a map of interface values"
defaultValue := ""
fieldInfo, err := buildFieldForMap(ctx, typesMap, name, goName, usage, defaultValue, false)
assert.NotNil(t, err)
assert.Equal(t, fmt.Errorf("map of type [interface{/* incomplete */}] is not supported."+
" Only basic slices or slices of json-unmarshalable types are supported"), err)
assert.NotNil(t, fieldInfo)
assert.Equal(t, "", fieldInfo.FlagMethodName)
assert.Equal(t, "", fieldInfo.DefaultValue)
})
t.Run("supported : StringToFloat64", func(t *testing.T) {
ctx := context.Background()
key := types.Typ[types.String]
elem := types.Typ[types.Float64]
typesMap := types.NewMap(key, elem)
name := "m"
goName := "Float64Map"
usage := "I'm a map of float64"
defaultValue := "DefaultValue"
fieldInfo, err := buildFieldForMap(ctx, typesMap, name, goName, usage, defaultValue, false)
assert.Nil(t, err)
assert.NotNil(t, fieldInfo)
assert.Equal(t, "StringToFloat64", fieldInfo.FlagMethodName)
assert.Equal(t, defaultValue, fieldInfo.DefaultValue)
})
}

func TestDiscoverFieldsRecursive(t *testing.T) {
t.Run("empty struct", func(t *testing.T) {
ctx := context.Background()
defaultValueAccessor := "defaultAccessor"
fieldPath := "field.Path"
pkg := types.NewPackage("p", "p")
n1 := types.NewTypeName(token.NoPos, pkg, "T1", nil)
namedTypes := types.NewNamed(n1, new(types.Struct), nil)
//namedTypes := types.NewNamed(n1, nil, nil)
fields, err := discoverFieldsRecursive(ctx, namedTypes, defaultValueAccessor, fieldPath, false)
assert.Nil(t, err)
assert.Equal(t, len(fields), 0)
})
}
1 change: 1 addition & 0 deletions cli/pflags/api/sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type TestType struct {
IgnoredMap map[string]string `json:"ignored-map" pflag:"-,"`
StorageConfig storage.Config `json:"storage"`
IntValue *int `json:"i"`
StringMap map[string]string `json:"m" pflag:",I'm a map of strings"`
}

type NestedType struct {
Expand Down
Loading

0 comments on commit fd1a616

Please sign in to comment.