Skip to content

Commit

Permalink
Tidy up Run behavior and GetFlag error states
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Jan 7, 2025
1 parent 86204ce commit d60fd7c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
13 changes: 12 additions & 1 deletion run.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,27 @@ package cli
import (
"context"
"errors"
"fmt"
"io"
"os"
)

// ParseAndRun parses the command hierarchy and runs the command. A convenience function that
// combines [Parse] and [Run] into a single call. See [Parse] and [Run] for more details.
func ParseAndRun(ctx context.Context, root *Command, args []string, options *RunOptions) error {
func ParseAndRun(ctx context.Context, root *Command, args []string, options *RunOptions) (retErr error) {
if err := Parse(root, args); err != nil {
return err
}
defer func() {
if r := recover(); r != nil {
switch err := r.(type) {
case error:
retErr = fmt.Errorf("internal: %v", err)
default:
retErr = fmt.Errorf("recovered: %v", r)
}
}
}()
return Run(ctx, root, options)
}

Expand Down
27 changes: 12 additions & 15 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ import (
"io"
)

// State represents the shared state for a command execution. It maintains a hierarchical structure
// that allows child commands to access global flags defined in parent commands. Use [GetFlag] to
// retrieve flag values by name.
// State holds command information during Exec function execution, allowing child commands to access
// parent flags. Use [GetFlag] to get flag values across the command hierarchy.
type State struct {
// Args contains the remaining arguments after flag parsing.
Args []string
Expand All @@ -20,18 +19,16 @@ type State struct {
commandPath []*Command
}

// GetFlag retrieves a flag value by name, with type inference. It traverses up the state hierarchy
// to find the flag, allowing access to parent command flags. Example usage:
// GetFlag retrieves a flag value by name from the command hierarchy. It first checks the current
// command's flags, then walks up through parent commands.
//
// If the flag doesn't exist or if the type doesn't match the requested type T an error will be
// raised in the Run function. This is an internal error and should never happen in normal usage.
// This ensures flag-related programming errors are caught early during development.
//
// verbose := GetFlag[bool](state, "verbose")
// count := GetFlag[int](state, "count")
// path := GetFlag[string](state, "path")
//
// If the flag isn't known, or is the wrong type, it panics with a detailed error message.
//
// Why panic? Because if a flag is missing, it's likely a programming error or a missing flag
// definition, and it's better to fail LOUD and EARLY than to silently ignore the issue and cause
// unexpected behavior.
func GetFlag[T any](s *State, name string) T {
// Try to find the flag in each command's flag set, starting from the current command
for i := len(s.commandPath) - 1; i >= 0; i-- {
Expand All @@ -46,22 +43,22 @@ func GetFlag[T any](s *State, name string) T {
if v, ok := value.(T); ok {
return v
}
msg := fmt.Sprintf("internal error: type mismatch for flag %q in command %q: registered %T, requested %T",
err := fmt.Errorf("type mismatch for flag %q in command %q: registered %T, requested %T",
formatFlagName(name),
getCommandPath(s.commandPath),
value,
*new(T),
)
// Flag exists but type doesn't match - this is an internal error
panic(msg)
panic(err)
}
}
}

// If flag not found anywhere in hierarchy, panic with helpful message
msg := fmt.Sprintf("internal error: flag %q not found in %q flag set",
err := fmt.Errorf("flag %q not found in command %q flag set",
formatFlagName(name),
getCommandPath(s.commandPath),
)
panic(msg)
panic(err)
}
8 changes: 6 additions & 2 deletions state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ func TestGetFlag(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
assert.Equal(t, `internal error: flag "-version" not found in "root" flag set`, r)
err, ok := r.(error)
require.True(t, ok)
assert.ErrorContains(t, err, `flag "-version" not found in command "root" flag set`)
}()
// Panic because author tried to access a flag that doesn't exist in any of the commands
_ = GetFlag[string](state, "version")
Expand All @@ -38,7 +40,9 @@ func TestGetFlag(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
assert.Equal(t, `internal error: type mismatch for flag "-version" in command "root": registered string, requested int`, r)
err, ok := r.(error)
require.True(t, ok)
assert.ErrorContains(t, err, `type mismatch for flag "-version" in command "root": registered string, requested int`)
}()
// Panic because author tried to access a registered flag with the wrong type
_ = GetFlag[int](state, "version")
Expand Down

0 comments on commit d60fd7c

Please sign in to comment.