Skip to content

Commit

Permalink
Expose Path method on *Command
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Jan 7, 2025
1 parent d60fd7c commit cf5befe
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 62 deletions.
17 changes: 12 additions & 5 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,21 @@ type Command struct {
state *State
}

func (c *Command) terminal() (*Command, *State) {
if c.state == nil || len(c.state.commandPath) == 0 {
return c, c.state
// Path returns the command chain from root to current command. It can only be called after the root
// command has been parsed and the command hierarchy has been established.
func (c *Command) Path() []*Command {
if c.state == nil {
return nil
}
return c.state.path
}

func (c *Command) terminal() *Command {
if c.state == nil || len(c.state.path) == 0 {
return c
}
// Get the last command in the path - this is our terminal command
terminalCmd := c.state.commandPath[len(c.state.commandPath)-1]
return terminalCmd, c.state
return c.state.path[len(c.state.path)-1]
}

// FlagMetadata holds additional metadata for a flag, such as whether it is required.
Expand Down
14 changes: 7 additions & 7 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ func Parse(root *Command, args []string) error {
// Initialize or update root state
if root.state == nil {
root.state = &State{
commandPath: []*Command{root},
path: []*Command{root},
}
} else {
// Reset command path but preserve other state
root.state.commandPath = []*Command{root}
root.state.path = []*Command{root}
}
// First split args at the -- delimiter if present
var argsToParse []string
Expand Down Expand Up @@ -90,7 +90,7 @@ func Parse(root *Command, args []string) error {
// Try to traverse to subcommand
if len(current.SubCommands) > 0 {
if sub := current.findSubCommand(arg); sub != nil {
root.state.commandPath = append(slices.Clone(root.state.commandPath), sub)
root.state.path = append(slices.Clone(root.state.path), sub)
if sub.Flags == nil {
sub.Flags = flag.NewFlagSet(sub.Name, flag.ContinueOnError)
}
Expand Down Expand Up @@ -133,7 +133,7 @@ func Parse(root *Command, args []string) error {

// Let ParseToEnd handle the flag parsing
if err := xflag.ParseToEnd(combinedFlags, argsToParse); err != nil {
return fmt.Errorf("command %q: %w", getCommandPath(root.state.commandPath), err)
return fmt.Errorf("command %q: %w", getCommandPath(root.state.path), err)
}

// Check required flags
Expand All @@ -146,7 +146,7 @@ func Parse(root *Command, args []string) error {
}
flag := combinedFlags.Lookup(flagMetadata.Name)
if flag == nil {
return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(root.state.commandPath), formatFlagName(flagMetadata.Name))
return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(root.state.path), formatFlagName(flagMetadata.Name))
}
if _, isBool := flag.Value.(interface{ IsBoolFlag() bool }); isBool {
isSet := false
Expand All @@ -170,7 +170,7 @@ func Parse(root *Command, args []string) error {
if len(missingFlags) > 1 {
msg += "s"
}
return fmt.Errorf("command %q: %s %q not set", getCommandPath(root.state.commandPath), msg, strings.Join(missingFlags, ", "))
return fmt.Errorf("command %q: %s %q not set", getCommandPath(root.state.path), msg, strings.Join(missingFlags, ", "))
}

// Skip past command names in remaining args
Expand Down Expand Up @@ -201,7 +201,7 @@ func Parse(root *Command, args []string) error {
root.state.Args = finalArgs

if current.Exec == nil {
return fmt.Errorf("command %q: no exec function defined", getCommandPath(root.state.commandPath))
return fmt.Errorf("command %q: no exec function defined", getCommandPath(root.state.path))
}
return nil
}
Expand Down
77 changes: 47 additions & 30 deletions parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,10 @@ func TestParse(t *testing.T) {

err := Parse(s.root, []string{"add", "item1"})
require.NoError(t, err)
require.NotNil(t, s.root.state)
require.NotEmpty(t, s.root.state.commandPath)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

require.Equal(t, s.add, cmd)
require.False(t, GetFlag[bool](state, "dry-run"))
require.False(t, GetFlag[bool](s.root.state, "dry-run"))
})
t.Run("unknown flag", func(t *testing.T) {
t.Parallel()
Expand All @@ -171,9 +170,10 @@ func TestParse(t *testing.T) {

err := Parse(s.root, []string{"add", "--dry-run", "item1"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.add, cmd)
assert.True(t, GetFlag[bool](state, "dry-run"))
assert.True(t, GetFlag[bool](s.root.state, "dry-run"))
})
t.Run("help flag", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -221,65 +221,71 @@ func TestParse(t *testing.T) {

err := Parse(s.root, []string{"add", "--dry-run", "item1", "--verbose"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.add, cmd)
assert.True(t, GetFlag[bool](state, "dry-run"))
assert.True(t, GetFlag[bool](state, "verbose"))
assert.True(t, GetFlag[bool](s.root.state, "dry-run"))
assert.True(t, GetFlag[bool](s.root.state, "verbose"))
})
t.Run("nested subcommand and root flag", func(t *testing.T) {
t.Parallel()
s := newTestState()

err := Parse(s.root, []string{"--verbose", "nested", "sub", "--echo", "hello"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.sub, cmd)
assert.Equal(t, "hello", GetFlag[string](state, "echo"))
assert.True(t, GetFlag[bool](state, "verbose"))
assert.Equal(t, "hello", GetFlag[string](s.root.state, "echo"))
assert.True(t, GetFlag[bool](s.root.state, "verbose"))
})
t.Run("nested subcommand with mixed flags", func(t *testing.T) {
t.Parallel()
s := newTestState()

err := Parse(s.root, []string{"nested", "sub", "--echo", "hello", "--verbose"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.sub, cmd)
assert.Equal(t, "hello", GetFlag[string](state, "echo"))
assert.True(t, GetFlag[bool](state, "verbose"))
assert.Equal(t, "hello", GetFlag[string](s.root.state, "echo"))
assert.True(t, GetFlag[bool](s.root.state, "verbose"))
})
t.Run("end of options delimiter", func(t *testing.T) {
t.Parallel()
s := newTestState()

err := Parse(s.root, []string{"--verbose", "--", "nested", "sub", "--echo", "hello"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.root, cmd)
assert.Equal(t, []string{"nested", "sub", "--echo", "hello"}, state.Args)
assert.True(t, GetFlag[bool](state, "verbose"))
assert.Equal(t, []string{"nested", "sub", "--echo", "hello"}, s.root.state.Args)
assert.True(t, GetFlag[bool](s.root.state, "verbose"))
})
t.Run("flags and args", func(t *testing.T) {
t.Parallel()
s := newTestState()

err := Parse(s.root, []string{"add", "item1", "--dry-run", "item2"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.add, cmd)
assert.True(t, GetFlag[bool](state, "dry-run"))
assert.Equal(t, []string{"item1", "item2"}, state.Args)
assert.True(t, GetFlag[bool](s.root.state, "dry-run"))
assert.Equal(t, []string{"item1", "item2"}, s.root.state.Args)
})
t.Run("nested subcommand with flags and args", func(t *testing.T) {
t.Parallel()
s := newTestState()

err := Parse(s.root, []string{"nested", "sub", "--echo", "hello", "world"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.sub, cmd)
assert.Equal(t, "hello", GetFlag[string](state, "echo"))
assert.Equal(t, []string{"world"}, state.Args)
assert.Equal(t, "hello", GetFlag[string](s.root.state, "echo"))
assert.Equal(t, []string{"world"}, s.root.state.Args)
})
t.Run("subcommand flags not available in parent", func(t *testing.T) {
t.Parallel()
Expand All @@ -295,9 +301,10 @@ func TestParse(t *testing.T) {

err := Parse(s.root, []string{"nested", "sub", "--force"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.sub, cmd)
assert.True(t, GetFlag[bool](state, "force"))
assert.True(t, GetFlag[bool](s.root.state, "force"))
})
t.Run("unrelated subcommand flags not inherited in other subcommands", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -329,18 +336,19 @@ func TestParse(t *testing.T) {
s := newTestState()
err := Parse(s.root, []string{"nested", "hello", "--mandatory-flag=true", "--another-mandatory-flag", "some-value"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := getCommand(t, s.root)

assert.Equal(t, s.hello, cmd)
require.True(t, GetFlag[bool](state, "mandatory-flag"))
require.True(t, GetFlag[bool](s.root.state, "mandatory-flag"))
}
{
// Correct type - false
s := newTestState()
err := Parse(s.root, []string{"nested", "hello", "--mandatory-flag=false", "--another-mandatory-flag=some-value"})
require.NoError(t, err)
cmd, state := s.root.terminal()
cmd := s.root.terminal()
assert.Equal(t, s.hello, cmd)
require.False(t, GetFlag[bool](state, "mandatory-flag"))
require.False(t, GetFlag[bool](s.root.state, "mandatory-flag"))
}
{
// Incorrect type
Expand Down Expand Up @@ -377,3 +385,12 @@ func TestParse(t *testing.T) {
require.ErrorContains(t, err, `failed to parse: command ["root", "sub command"]: must contain only letters, no spaces or special characters`)
})
}

func getCommand(t *testing.T, c *Command) *Command {
require.NotNil(t, c)
require.NotNil(t, c.state)
require.NotEmpty(t, c.state.path)
terminal := c.terminal()
require.NotNil(t, terminal)
return terminal
}
5 changes: 2 additions & 3 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@ func Run(ctx context.Context, root *Command, options *RunOptions) error {
if root == nil {
return errors.New("root command is nil")
}
if root.state == nil || len(root.state.commandPath) == 0 {
if root.state == nil || len(root.state.path) == 0 {
return errors.New("command has not been parsed")
}
options = checkAndSetRunOptions(options)
updateState(root.state, options)

terminal, state := root.terminal()
return terminal.Exec(ctx, state)
return root.terminal().Exec(ctx, root.state)
}

func updateState(s *State, opt *RunOptions) {
Expand Down
12 changes: 7 additions & 5 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ type State struct {
Stdin io.Reader
Stdout, Stderr io.Writer

commandPath []*Command
// path is the command hierarchy from the root command to the current command. The root command
// is the first element in the path, and the terminal command is the last element.
path []*Command
}

// GetFlag retrieves a flag value by name from the command hierarchy. It first checks the current
Expand All @@ -31,8 +33,8 @@ type State struct {
// path := GetFlag[string](state, "path")
func GetFlag[T any](s *State, name string) T {
// Try to find the flag in each command's flag set, starting from the current command
for i := len(s.commandPath) - 1; i >= 0; i-- {
cmd := s.commandPath[i]
for i := len(s.path) - 1; i >= 0; i-- {
cmd := s.path[i]
if cmd.Flags == nil {
continue
}
Expand All @@ -45,7 +47,7 @@ func GetFlag[T any](s *State, name string) T {
}
err := fmt.Errorf("type mismatch for flag %q in command %q: registered %T, requested %T",
formatFlagName(name),
getCommandPath(s.commandPath),
getCommandPath(s.path),
value,
*new(T),
)
Expand All @@ -58,7 +60,7 @@ func GetFlag[T any](s *State, name string) T {
// If flag not found anywhere in hierarchy, panic with helpful message
err := fmt.Errorf("flag %q not found in command %q flag set",
formatFlagName(name),
getCommandPath(s.commandPath),
getCommandPath(s.path),
)
panic(err)
}
4 changes: 2 additions & 2 deletions state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestGetFlag(t *testing.T) {
Flags: flag.NewFlagSet("root", flag.ContinueOnError),
}
state := &State{
commandPath: []*Command{cmd},
path: []*Command{cmd},
}
defer func() {
r := recover()
Expand All @@ -35,7 +35,7 @@ func TestGetFlag(t *testing.T) {
Flags: FlagsFunc(func(f *flag.FlagSet) { f.String("version", "1.0.0", "show version") }),
}
state := &State{
commandPath: []*Command{cmd},
path: []*Command{cmd},
}
defer func() {
r := recover()
Expand Down
23 changes: 13 additions & 10 deletions usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import (
"github.com/mfridman/cli/pkg/textutil"
)

func DefaultUsage(c *Command) string {
if c == nil {
// DefaultUsage returns the default usage string for the command hierarchy. It is used when the
// command does not provide a custom usage function. The usage string includes the command's short
// help, usage pattern, available subcommands, and flags.
func DefaultUsage(root *Command) string {
if root == nil {
return ""
}

// Get terminal command from state
terminalCmd, _ := c.terminal()
terminalCmd := root.terminal()

var b strings.Builder

Expand All @@ -34,8 +37,8 @@ func DefaultUsage(c *Command) string {
b.WriteString(" " + terminalCmd.Usage + "\n")
} else {
usage := terminalCmd.Name
if c.state != nil && len(c.state.commandPath) > 0 {
usage = getCommandPath(c.state.commandPath)
if root.state != nil && len(root.state.path) > 0 {
usage = getCommandPath(root.state.path)
}
if terminalCmd.Flags != nil {
usage += " [flags]"
Expand Down Expand Up @@ -83,12 +86,12 @@ func DefaultUsage(c *Command) string {
}

var flags []flagInfo
if c.state != nil && len(c.state.commandPath) > 0 {
for i, cmd := range c.state.commandPath {
if root.state != nil && len(root.state.path) > 0 {
for i, cmd := range root.state.path {
if cmd.Flags == nil {
continue
}
isGlobal := i < len(c.state.commandPath)-1
isGlobal := i < len(root.state.path)-1
cmd.Flags.VisitAll(func(f *flag.Flag) {
flags = append(flags, flagInfo{
name: "-" + f.Name,
Expand Down Expand Up @@ -137,8 +140,8 @@ func DefaultUsage(c *Command) string {

if len(terminalCmd.SubCommands) > 0 {
cmdName := terminalCmd.Name
if c.state != nil && len(c.state.commandPath) > 0 {
cmdName = getCommandPath(c.state.commandPath)
if root.state != nil && len(root.state.path) > 0 {
cmdName = getCommandPath(root.state.path)
}
fmt.Fprintf(&b, "Use \"%s [command] --help\" for more information about a command.\n", cmdName)
}
Expand Down

0 comments on commit cf5befe

Please sign in to comment.