Skip to content

Commit

Permalink
Add parse tests and improve required flags logic
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Dec 27, 2024
1 parent 1239c35 commit c5cb6c4
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 36 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/internal/
internal/
tmp/
45 changes: 18 additions & 27 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,36 +102,27 @@ func Parse(root *Command, args []string) error {
return fmt.Errorf("command %q: %w", current.Name, err)
}

// Check required flags by inspecting the args string for their presence
if len(current.FlagsMetadata) > 0 {
var missingFlags []string
for _, flagMetadata := range current.FlagsMetadata {
if !flagMetadata.Required {
continue
}
// TODO(mf): we need to validate that the metadata flag is known to the flag set
flag := combinedFlags.Lookup(flagMetadata.Name)
if flag == nil {
return fmt.Errorf("command %q: internal error: required flag %q not found in flag set", current.Name, flagMetadata.Name)
}
// Look for the flag in the original args before any delimiter
found := false
for _, arg := range argsToParse {
// Match either -flag or --flag
if arg == "-"+flagMetadata.Name || arg == "--"+flagMetadata.Name ||
strings.HasPrefix(arg, "-"+flagMetadata.Name+"=") ||
strings.HasPrefix(arg, "--"+flagMetadata.Name+"=") {
found = true
break
// Check required flags by checking if they were actually set to non-default values
var missingFlags []string
for _, cmd := range commandChain {
if len(cmd.FlagsMetadata) > 0 {
for _, flagMetadata := range cmd.FlagsMetadata {
if !flagMetadata.Required {
continue
}
flag := combinedFlags.Lookup(flagMetadata.Name)
if flag == nil {
return fmt.Errorf("command %q: internal error: required flag %q not found in flag set", current.Name, flagMetadata.Name)
}
// Check if the flag was set by checking its actual value against default
if flag.Value.String() == flag.DefValue {
missingFlags = append(missingFlags, flagMetadata.Name)
}
}
if !found {
missingFlags = append(missingFlags, flagMetadata.Name)
}
}
if len(missingFlags) > 0 {
return fmt.Errorf("command %q: required flag(s) %q not set", current.Name, strings.Join(missingFlags, ", "))
}
}
if len(missingFlags) > 0 {
return fmt.Errorf("command %q: required flag(s) %q not set", current.Name, strings.Join(missingFlags, ", "))
}

// Skip past command names in remaining args from flag parsing
Expand Down
75 changes: 67 additions & 8 deletions parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ import (

// testState is a helper struct to hold the commands for testing
//
// root --verbose --version
// ├── add --dry-run
// └── nested --force
// └── sub --echo
// root --verbose --version
// ├── add --dry-run
// └── nested --force
// └── sub --echo
// └── hello --mandatory-flag
type testState struct {
add *Command
nested, sub *Command
root *Command
add *Command
nested, sub, hello *Command
root *Command
}

func newTestState() testState {
Expand All @@ -37,14 +38,27 @@ func newTestState() testState {
Flags: FlagsFunc(func(fset *flag.FlagSet) {
fset.String("echo", "", "echo the message")
}),
FlagsMetadata: []FlagMetadata{
{Name: "echo", Required: false}, // not required
},
Exec: exec,
}
hello := &Command{
Name: "hello",
Flags: FlagsFunc(func(fset *flag.FlagSet) {
fset.Bool("mandatory-flag", false, "mandatory flag")
}),
FlagsMetadata: []FlagMetadata{
{Name: "mandatory-flag", Required: true},
},
Exec: exec,
}
nested := &Command{
Name: "nested",
Flags: FlagsFunc(func(fset *flag.FlagSet) {
fset.Bool("force", false, "force the operation")
}),
SubCommands: []*Command{sub},
SubCommands: []*Command{sub, hello},
Exec: exec,
}
root := &Command{
Expand Down Expand Up @@ -290,4 +304,49 @@ func TestParse(t *testing.T) {
require.Error(t, err)
require.ErrorContains(t, err, `subcommand in path "todo nested" has no name`)
})
t.Run("required flag not set", func(t *testing.T) {
t.Parallel()
s := newTestState()

err := Parse(s.root, []string{"nested", "hello"})
require.Error(t, err)
// TODO(mf): this error message should have the full path to the command, e.g., "todo nested hello"
require.ErrorContains(t, err, `command "hello": required flag(s) "mandatory-flag" not set`)

// Correct type
err = Parse(s.root, []string{"nested", "hello", "--mandatory-flag", "true"})
require.NoError(t, err)
require.True(t, GetFlag[bool](s.root.selected.state, "mandatory-flag"))
// Incorrect type
err = Parse(s.root, []string{"nested", "hello", "--mandatory-flag=not-a-bool"})
require.Error(t, err)
require.ErrorContains(t, err, `command "hello": invalid boolean value "not-a-bool" for -mandatory-flag: parse error`)
})
t.Run("unknown required flag set by cli author", func(t *testing.T) {
t.Parallel()
cmd := &Command{
Name: "root",
FlagsMetadata: []FlagMetadata{
{Name: "some-other-flag", Required: true},
},
}
err := Parse(cmd, nil)
require.Error(t, err)
// TODO(mf): consider improving this error message so it's obvious that a "required" flag
// was set by the cli author but not registered in the flag set
require.ErrorContains(t, err, `command "root": internal error: required flag "some-other-flag" not found in flag set`)
})
t.Run("space in command name", func(t *testing.T) {
t.Parallel()
cmd := &Command{
Name: "root",
SubCommands: []*Command{
{Name: "sub command"},
},
}
err := Parse(cmd, nil)
require.Error(t, err)
// TODO(mf): consider improving this error message so it's a bit more user-friendly
require.ErrorContains(t, err, `command name "sub command" contains spaces`)
})
}

0 comments on commit c5cb6c4

Please sign in to comment.