Skip to content

Commit

Permalink
Add FlagConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
dearchap committed Nov 13, 2022
1 parent e926dde commit 30261ac
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 67 deletions.
1 change: 1 addition & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2654,6 +2654,7 @@ func TestFlagAction(t *testing.T) {
stringFlag := &StringFlag{
Name: "f_string",
Action: func(c *Context, v string) error {
t.Log("in adction")
if v == "" {
return fmt.Errorf("empty string")
}
Expand Down
4 changes: 2 additions & 2 deletions flag_bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type boolValue struct {
count *int
}

func (i boolValue) Create(val bool, p *bool /*, count *int*/) flag.Value {
func (i boolValue) Create(val bool, p *bool, c FlagConfig) flag.Value {
*p = val
return &boolValue{
destination: p,
Expand Down Expand Up @@ -57,7 +57,7 @@ func (b *boolValue) Count() int {
return 0
}

type BoolFlag = flagImpl[bool, boolValue]
type BoolFlag = FlagBase[bool, boolValue]

func (cCtx *Context) Bool(name string) bool {
if v, ok := cCtx.Value(name).(bool); ok {
Expand Down
4 changes: 2 additions & 2 deletions flag_duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
// -- time.Duration Value
type durationValue time.Duration

func (i durationValue) Create(val time.Duration, p *time.Duration) flag.Value {
func (i durationValue) Create(val time.Duration, p *time.Duration, c FlagConfig) flag.Value {
*p = val
return (*durationValue)(p)
}
Expand All @@ -26,7 +26,7 @@ func (d *durationValue) Get() any { return time.Duration(*d) }

func (d *durationValue) String() string { return (*time.Duration)(d).String() }

type DurationFlag = flagImpl[time.Duration, durationValue]
type DurationFlag = FlagBase[time.Duration, durationValue]

func (cCtx *Context) Duration(name string) time.Duration {
if v, ok := cCtx.Value(name).(time.Duration); ok {
Expand Down
4 changes: 2 additions & 2 deletions flag_float64.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
// -- float64 Value
type float64Value float64

func (f float64Value) Create(val float64, p *float64) flag.Value {
func (f float64Value) Create(val float64, p *float64, c FlagConfig) flag.Value {
*p = val
return (*float64Value)(p)
}
Expand All @@ -26,7 +26,7 @@ func (f *float64Value) Get() any { return float64(*f) }

func (f *float64Value) String() string { return strconv.FormatFloat(float64(*f), 'g', -1, 64) }

type Float64Flag = flagImpl[float64, float64Value]
type Float64Flag = FlagBase[float64, float64Value]

// Int looks up the value of a local IntFlag, returns
// 0 if not found
Expand Down
47 changes: 28 additions & 19 deletions flag_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import (
"fmt"
)

type FlagConfig interface {
IntBase() int
}

type ValueCreator[T any] interface {
Create(t T, d *T) flag.Value
Create(T, *T, FlagConfig) flag.Value
}

// Float64Flag is a flag with type float64
type flagImpl[T any, F ValueCreator[T]] struct {
type FlagBase[T any, VC ValueCreator[T]] struct {
Name string

Category string
Expand All @@ -32,21 +36,26 @@ type flagImpl[T any, F ValueCreator[T]] struct {
Base int
Count *int

creator F
Action func(*Context, T) error

creator VC
value flag.Value
Action func(*Context, T) error
}

func (f *FlagBase[T, V]) IntBase() int {
return f.Base
}

// GetValue returns the flags value as string representation and an empty
// string if the flag takes no value at all.
func (f *flagImpl[T, V]) GetValue() string {
func (f *FlagBase[T, V]) GetValue() string {
return fmt.Sprintf("%v", f.Value)
}

// Apply populates the flag given the flag set and environment
func (f *flagImpl[T, V]) Apply(set *flag.FlagSet) error {
func (f *FlagBase[T, V]) Apply(set *flag.FlagSet) error {
if f.Destination == nil {
f.value = f.creator.Create(f.Value, new(T))
f.value = f.creator.Create(f.Value, new(T), f)
}

if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
Expand All @@ -61,7 +70,7 @@ func (f *flagImpl[T, V]) Apply(set *flag.FlagSet) error {

for _, name := range f.Names() {
if f.Destination != nil {
f.value = f.creator.Create(f.Value, f.Destination)
f.value = f.creator.Create(f.Value, f.Destination, f)
set.Var(f.value, name, f.Usage)
continue
}
Expand All @@ -72,60 +81,60 @@ func (f *flagImpl[T, V]) Apply(set *flag.FlagSet) error {
}

// String returns a readable representation of this value (for usage defaults)
func (f *flagImpl[T, V]) String() string {
func (f *FlagBase[T, V]) String() string {
return FlagStringer(f)
}

// IsSet returns whether or not the flag has been set through env or file
func (f *flagImpl[T, V]) IsSet() bool {
func (f *FlagBase[T, V]) IsSet() bool {
return f.HasBeenSet
}

// Names returns the names of the flag
func (f *flagImpl[T, V]) Names() []string {
func (f *FlagBase[T, V]) Names() []string {
return FlagNames(f.Name, f.Aliases)
}

// IsRequired returns whether or not the flag is required
func (f *flagImpl[T, V]) IsRequired() bool {
func (f *FlagBase[T, V]) IsRequired() bool {
return f.Required
}

// IsVisible returns true if the flag is not hidden, otherwise false
func (f *flagImpl[T, V]) IsVisible() bool {
func (f *FlagBase[T, V]) IsVisible() bool {
return !f.Hidden
}

// GetCategory returns the category of the flag
func (f *flagImpl[T, V]) GetCategory() string {
func (f *FlagBase[T, V]) GetCategory() string {
return f.Category
}

// GetUsage returns the usage string for the flag
func (f *flagImpl[T, V]) GetUsage() string {
func (f *FlagBase[T, V]) GetUsage() string {
return f.Usage
}

// GetEnvVars returns the env vars for this flag
func (f *flagImpl[T, V]) GetEnvVars() []string {
func (f *FlagBase[T, V]) GetEnvVars() []string {
return f.EnvVars
}

// TakesValue returns true if the flag takes a value, otherwise false
func (f *flagImpl[T, V]) TakesValue() bool {
func (f *FlagBase[T, V]) TakesValue() bool {
return "Float64Flag" != "BoolFlag"
}

// GetDefaultText returns the default text for this flag
func (f *flagImpl[T, V]) GetDefaultText() string {
func (f *FlagBase[T, V]) GetDefaultText() string {
if f.DefaultText != "" {
return f.DefaultText
}
return f.GetValue()
}

// Get returns the flag’s value in the given Context.
func (f *flagImpl[T, V]) Get(ctx *Context) T {
func (f *FlagBase[T, V]) Get(ctx *Context) T {
if v, ok := ctx.Value(f.Name).(T); ok {
return v
}
Expand Down
22 changes: 14 additions & 8 deletions flag_int.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,33 @@ import (
)

// -- int Value
type intValue int
type intValue struct {
val *int
base int
}

func (i intValue) Create(val int, p *int) flag.Value {
func (i intValue) Create(val int, p *int, c FlagConfig) flag.Value {
*p = val
return (*intValue)(p)
return &intValue{
val: p,
base: c.IntBase(),
}
}

func (i *intValue) Set(s string) error {
v, err := strconv.ParseInt(s, 0, strconv.IntSize)
v, err := strconv.ParseInt(s, i.base, strconv.IntSize)
if err != nil {
return err
}
*i = intValue(v)
*i.val = int(v)
return err
}

func (i *intValue) Get() any { return int(*i) }
func (i *intValue) Get() any { return int(*i.val) }

func (i *intValue) String() string { return strconv.Itoa(int(*i)) }
func (i *intValue) String() string { return strconv.Itoa(int(*i.val)) }

type IntFlag = flagImpl[int, intValue]
type IntFlag = FlagBase[int, intValue]

// Int looks up the value of a local IntFlag, returns
// 0 if not found
Expand Down
20 changes: 13 additions & 7 deletions flag_int64.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,33 @@ import (
)

// -- int64 Value
type int64Value int64
type int64Value struct {
val *int64
base int
}

func (i int64Value) Create(val int64, p *int64) flag.Value {
func (i int64Value) Create(val int64, p *int64, c FlagConfig) flag.Value {
*p = val
return (*int64Value)(p)
return &int64Value{
val: p,
base: c.IntBase(),
}
}

func (i *int64Value) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 64)
if err != nil {
return err
}
*i = int64Value(v)
*i.val = v
return err
}

func (i *int64Value) Get() any { return int64(*i) }
func (i *int64Value) Get() any { return int64(*i.val) }

func (i *int64Value) String() string { return strconv.FormatInt(int64(*i), 10) }
func (i *int64Value) String() string { return strconv.FormatInt(int64(*i.val), 10) }

type Int64Flag = flagImpl[int64, int64Value]
type Int64Flag = FlagBase[int64, int64Value]

// Int64 looks up the value of a local Int64Flag, returns
// 0 if not found
Expand Down
4 changes: 2 additions & 2 deletions flag_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
// -- string Value
type stringValue string

func (i stringValue) Create(val string, p *string) flag.Value {
func (i stringValue) Create(val string, p *string, c FlagConfig) flag.Value {
*p = val
return (*stringValue)(p)
}
Expand All @@ -21,7 +21,7 @@ func (s *stringValue) Get() any { return string(*s) }

func (s *stringValue) String() string { return string(*s) }

type StringFlag = flagImpl[string, stringValue]
type StringFlag = FlagBase[string, stringValue]

func (cCtx *Context) String(name string) string {
if v, ok := cCtx.Value(name).(string); ok {
Expand Down
22 changes: 14 additions & 8 deletions flag_uint.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,33 @@ import (
)

// -- uint Value
type uintValue uint
type uintValue struct {
val *uint
base int
}

func (i uintValue) Create(val uint, p *uint) flag.Value {
func (i uintValue) Create(val uint, p *uint, c FlagConfig) flag.Value {
*p = val
return (*uintValue)(p)
return &uintValue{
val: p,
base: c.IntBase(),
}
}

func (i *uintValue) Set(s string) error {
v, err := strconv.ParseUint(s, 0, strconv.IntSize)
v, err := strconv.ParseUint(s, i.base, strconv.IntSize)
if err != nil {
return err
}
*i = uintValue(v)
*i.val = uint(v)
return err
}

func (i *uintValue) Get() any { return uint(*i) }
func (i *uintValue) Get() any { return uint(*i.val) }

func (i *uintValue) String() string { return strconv.FormatUint(uint64(*i), 10) }
func (i *uintValue) String() string { return strconv.FormatUint(uint64(*i.val), 10) }

type UintFlag = flagImpl[uint, uintValue]
type UintFlag = FlagBase[uint, uintValue]

// Int looks up the value of a local IntFlag, returns
// 0 if not found
Expand Down
22 changes: 14 additions & 8 deletions flag_uint64.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,33 @@ import (
)

// -- uint64 Value
type uint64Value uint64
type uint64Value struct {
val *uint64
base int
}

func (i uint64Value) Create(val uint64, p *uint64) flag.Value {
func (i uint64Value) Create(val uint64, p *uint64, c FlagConfig) flag.Value {
*p = val
return (*uint64Value)(p)
return &uint64Value{
val: p,
base: c.IntBase(),
}
}

func (i *uint64Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 64)
v, err := strconv.ParseUint(s, i.base, 64)
if err != nil {
return err
}
*i = uint64Value(v)
*i.val = v
return err
}

func (i *uint64Value) Get() any { return uint64(*i) }
func (i *uint64Value) Get() any { return uint64(*i.val) }

func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i), 10) }
func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i.val), 10) }

type Uint64Flag = flagImpl[uint64, uint64Value]
type Uint64Flag = FlagBase[uint64, uint64Value]

// Int64 looks up the value of a local Int64Flag, returns
// 0 if not found
Expand Down
Loading

0 comments on commit 30261ac

Please sign in to comment.