diff --git a/cmd/phlare/main.go b/cmd/phlare/main.go index 89f15589c..2cf7d6467 100644 --- a/cmd/phlare/main.go +++ b/cmd/phlare/main.go @@ -5,25 +5,96 @@ import ( "flag" "fmt" "os" + "sort" + + "github.com/grafana/dskit/flagext" + "github.com/prometheus/common/version" "github.com/grafana/phlare/pkg/cfg" "github.com/grafana/phlare/pkg/phlare" + "github.com/grafana/phlare/pkg/usage" _ "github.com/grafana/phlare/pkg/util/build" ) +type mainFlags struct { + phlare.Config + + PrintVersion bool + PrintModules bool + PrintHelp bool + PrintHelpAll bool +} + +func (mf *mainFlags) Clone() flagext.Registerer { + return func(mf mainFlags) *mainFlags { + return &mf + }(*mf) +} + +func (mf *mainFlags) PhlareConfig() *phlare.Config { + return &mf.Config +} + +func (mf *mainFlags) RegisterFlags(fs *flag.FlagSet) { + mf.Config.RegisterFlags(fs) + fs.BoolVar(&mf.PrintVersion, "version", false, "Show the version of phlare and exit") + fs.BoolVar(&mf.PrintModules, "modules", false, "List available modules that can be used as target and exit.") + fs.BoolVar(&mf.PrintHelp, "h", false, "Print basic help.") + fs.BoolVar(&mf.PrintHelp, "help", false, "Print basic help.") + fs.BoolVar(&mf.PrintHelpAll, "help-all", false, "Print help, also including advanced and experimental parameters.") +} + func main() { - var config phlare.Config - if err := cfg.DynamicUnmarshal(&config, os.Args[1:], flag.CommandLine); err != nil { + var ( + flags mainFlags + ) + + if err := cfg.DynamicUnmarshal(&flags, os.Args[1:], flag.CommandLine); err != nil { fmt.Fprintf(os.Stderr, "failed parsing config: %v\n", err) os.Exit(1) } - f, err := phlare.New(config) + f, err := phlare.New(flags.Config) if err != nil { fmt.Fprintf(os.Stderr, "failed creating phlare: %v\n", err) os.Exit(1) } + if flags.PrintVersion { + fmt.Println(version.Print("phlare")) + return + } + + if flags.PrintModules { + allDeps := f.ModuleManager.DependenciesForModule(phlare.All) + + for _, m := range f.ModuleManager.UserVisibleModuleNames() { + ix := sort.SearchStrings(allDeps, m) + included := ix < len(allDeps) && allDeps[ix] == m + + if included { + fmt.Fprintln(os.Stdout, m, "*") + } else { + fmt.Fprintln(os.Stdout, m) + } + } + + fmt.Fprintln(os.Stdout) + fmt.Fprintln(os.Stdout, "Modules marked with * are included in target All.") + return + } + + if flags.PrintHelp || flags.PrintHelpAll { + // Print available parameters to stdout, so that users can grep/less them easily. + flag.CommandLine.SetOutput(os.Stdout) + if err := usage.Usage(flags.PrintHelpAll, &flags); err != nil { + fmt.Fprintf(os.Stderr, "error printing usage: %s\n", err) + os.Exit(1) + } + + return + } + err = f.Run() if err != nil { fmt.Fprintf(os.Stderr, "failed running phlare: %v\n", err) diff --git a/cmd/phlare/main_test.go b/cmd/phlare/main_test.go new file mode 100644 index 000000000..6f09f5e73 --- /dev/null +++ b/cmd/phlare/main_test.go @@ -0,0 +1,103 @@ +package main + +import ( + "flag" + "os" + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/grafana/phlare/pkg/test" +) + +func TestFlagParsing(t *testing.T) { + for name, tc := range map[string]struct { + arguments []string + stdoutMessage string // string that must be included in stdout + stderrMessage string // string that must be included in stderr + stdoutExcluded string // string that must NOT be included in stdout + stderrExcluded string // string that must NOT be included in stderr + }{ + "help-short": { + arguments: []string{"-h"}, + stdoutMessage: "Usage of", // Usage must be on stdout, not stderr. + stderrExcluded: "Usage of", + }, + "help": { + arguments: []string{"-help"}, + stdoutMessage: "Usage of", // Usage must be on stdout, not stderr. + stderrExcluded: "Usage of", + }, + "help-all": { + arguments: []string{"-help-all"}, + stdoutMessage: "Usage of", // Usage must be on stdout, not stderr. + stderrExcluded: "Usage of", + }, + "user visible module listing": { + arguments: []string{"-modules"}, + stdoutMessage: "ingester *\n", + stderrExcluded: "ingester\n", + }, + "version": { + arguments: []string{"-version"}, + stdoutMessage: "phlare, version", + stderrExcluded: "phlare, version", + }, + } { + t.Run(name, func(t *testing.T) { + _ = os.Setenv("TARGET", "ingester") + oldDefaultRegistry := prometheus.DefaultRegisterer + defer func() { + prometheus.DefaultRegisterer = oldDefaultRegistry + }() + // We need to reset the default registry to avoid + // "duplicate metrics collector registration attempted" errors. + prometheus.DefaultRegisterer = prometheus.NewRegistry() + testSingle(t, tc.arguments, tc.stdoutMessage, tc.stderrMessage, tc.stdoutExcluded, tc.stderrExcluded) + }) + } +} + +func testSingle(t *testing.T, arguments []string, stdoutMessage, stderrMessage, stdoutExcluded, stderrExcluded string) { + t.Helper() + oldArgs, oldStdout, oldStderr := os.Args, os.Stdout, os.Stderr + restored := false + restoreIfNeeded := func() { + if restored { + return + } + os.Stdout = oldStdout + os.Stderr = oldStderr + os.Args = oldArgs + restored = true + } + defer restoreIfNeeded() + + arguments = append([]string{"./phlare"}, arguments...) + + os.Args = arguments + co := test.CaptureOutput(t) + + // reset default flags + flag.CommandLine = flag.NewFlagSet(arguments[0], flag.ExitOnError) + + main() + + stdout, stderr := co.Done() + + // Restore stdout and stderr before reporting errors to make them visible. + restoreIfNeeded() + if !strings.Contains(stdout, stdoutMessage) { + t.Errorf("Expected on stdout: %q, stdout: %s\n", stdoutMessage, stdout) + } + if !strings.Contains(stderr, stderrMessage) { + t.Errorf("Expected on stderr: %q, stderr: %s\n", stderrMessage, stderr) + } + if len(stdoutExcluded) > 0 && strings.Contains(stdout, stdoutExcluded) { + t.Errorf("Unexpected output on stdout: %q, stdout: %s\n", stdoutExcluded, stdout) + } + if len(stderrExcluded) > 0 && strings.Contains(stderr, stderrExcluded) { + t.Errorf("Unexpected output on stderr: %q, stderr: %s\n", stderrExcluded, stderr) + } +} diff --git a/pkg/cfg/files.go b/pkg/cfg/files.go index 829312e7e..d29179899 100644 --- a/pkg/cfg/files.go +++ b/pkg/cfg/files.go @@ -74,17 +74,10 @@ func YAMLFlag(args []string, name string) Source { // parsing out the config file location. dst.Clone().RegisterFlags(freshFlags) - usage := freshFlags.Usage freshFlags.Usage = func() { /* don't do anything by default, we will print usage ourselves, but only when requested. */ } - err := freshFlags.Parse(args) - if err == flag.ErrHelp { - // print available parameters to stdout, so that users can grep/less it easily - freshFlags.SetOutput(os.Stdout) - usage() - os.Exit(2) - } else if err != nil { - fmt.Fprintln(freshFlags.Output(), "Run with -help to get list of available parameters") + if err := freshFlags.Parse(args); err != nil { + fmt.Fprintln(freshFlags.Output(), "Run with -help to get a list of available parameters") os.Exit(2) } @@ -99,6 +92,5 @@ func YAMLFlag(args []string, name string) Source { } return YAML(f.Value.String(), expandEnv)(dst) - } } diff --git a/pkg/phlare/phlare.go b/pkg/phlare/phlare.go index 2415c54ed..16b2ee2a9 100644 --- a/pkg/phlare/phlare.go +++ b/pkg/phlare/phlare.go @@ -62,7 +62,6 @@ type Config struct { Analytics usagestats.Config `yaml:"analytics"` ConfigFile string `yaml:"-"` - ShowVersion bool `yaml:"-"` ConfigExpandEnv bool `yaml:"-"` } @@ -93,7 +92,6 @@ func (c *Config) RegisterFlagsWithContext(ctx context.Context, f *flag.FlagSet) f.Var(&c.Target, "target", "Comma-separated list of Phlare modules to load. "+ "The alias 'all' can be used in the list to load a number of core modules and will enable single-binary mode. ") f.BoolVar(&c.MultitenancyEnabled, "auth.multitenancy-enabled", false, "When set to true, incoming HTTP requests must specify tenant ID in HTTP X-Scope-OrgId header. When set to false, tenant ID anonymous is used instead.") - f.BoolVar(&c.ShowVersion, "version", false, "Show the version of phlare and exit") f.BoolVar(&c.ConfigExpandEnv, "config.expand-env", false, "Expands ${var} in config according to the values of the environment variables.") c.registerServerFlagsWithChangedDefaultValues(f) @@ -138,13 +136,18 @@ func (c *Config) Validate() error { return c.AgentConfig.Validate() } +type phlareConfigGetter interface { + PhlareConfig() *Config +} + func (c *Config) ApplyDynamicConfig() cfg.Source { c.Ingester.LifecyclerConfig.RingConfig.KVStore.Store = "memberlist" return func(dst cfg.Cloneable) error { - r, ok := dst.(*Config) + g, ok := dst.(phlareConfigGetter) if !ok { - return errors.New("dst is not a Phlare config") + return fmt.Errorf("dst is not a Phlare config getter %T", dst) } + r := g.PhlareConfig() if r.AgentConfig.ClientConfig.URL.String() == "" { listenAddress := "0.0.0.0" if c.Server.HTTPListenAddress != "" { @@ -195,11 +198,6 @@ func New(cfg Config) (*Phlare, error) { logger := initLogger(&cfg.Server) usagestats.Edition("oss") - if cfg.ShowVersion { - fmt.Println(version.Print("phlare")) - os.Exit(0) - } - phlare := &Phlare{ Cfg: cfg, logger: logger, diff --git a/pkg/phlare/phlare_test.go b/pkg/phlare/phlare_test.go index 10a20a263..0744b2ee2 100644 --- a/pkg/phlare/phlare_test.go +++ b/pkg/phlare/phlare_test.go @@ -22,6 +22,9 @@ func TestFlagDefaults(t *testing.T) { f.PrintDefaults() const delim = '\n' + // Because this is a short flag, it will be printed on the same line as the + // flag name. So we need to ignore this special case. + const ignoredHelpFlags = "-h\tPrint basic help." // Populate map with parsed default flags. // Key is the flag and value is the default text. @@ -33,6 +36,10 @@ func TestFlagDefaults(t *testing.T) { } require.NoError(t, err) + if strings.Contains(line, ignoredHelpFlags) { + continue + } + nextLine, err := buf.ReadString(delim) require.NoError(t, err) diff --git a/pkg/test/capture.go b/pkg/test/capture.go new file mode 100644 index 000000000..dc6161256 --- /dev/null +++ b/pkg/test/capture.go @@ -0,0 +1,68 @@ +package test + +import ( + "bytes" + "io" + "os" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +type CapturedOutput struct { + stdoutBuf bytes.Buffer + stderrBuf bytes.Buffer + + wg sync.WaitGroup + stdoutReader, stdoutWriter *os.File + stderrReader, stderrWriter *os.File +} + +// CaptureOutput replaces os.Stdout and os.Stderr with new pipes, that will +// write output to buffers. Buffers are accessible by calling Done on returned +// struct. +// +// os.Stdout and os.Stderr must be reverted to previous values manually. +func CaptureOutput(t *testing.T) *CapturedOutput { + stdoutR, stdoutW, err := os.Pipe() + require.NoError(t, err) + + stderrR, stderrW, err := os.Pipe() + require.NoError(t, err) + + os.Stdout = stdoutW + os.Stderr = stderrW + + co := &CapturedOutput{ + stdoutReader: stdoutR, + stdoutWriter: stdoutW, + stderrReader: stderrR, + stderrWriter: stderrW, + } + co.wg.Add(1) + go func() { + defer co.wg.Done() + _, _ = io.Copy(&co.stdoutBuf, stdoutR) + }() + + co.wg.Add(1) + go func() { + defer co.wg.Done() + _, _ = io.Copy(&co.stderrBuf, stderrR) + }() + + return co +} + +// Done waits until all captured output has been written to buffers, +// and then returns the buffers. +func (co *CapturedOutput) Done() (stdout string, stderr string) { + // we need to close writers for readers to stop + _ = co.stdoutWriter.Close() + _ = co.stderrWriter.Close() + + co.wg.Wait() + + return co.stdoutBuf.String(), co.stderrBuf.String() +} diff --git a/pkg/usage/usage.go b/pkg/usage/usage.go new file mode 100644 index 000000000..c1c891aab --- /dev/null +++ b/pkg/usage/usage.go @@ -0,0 +1,245 @@ +//nolint:goconst +package usage + +import ( + "flag" + "fmt" + "os" + "reflect" + "strings" + + "github.com/grafana/dskit/flagext" + + "github.com/grafana/phlare/pkg/util/fieldcategory" +) + +// Usage prints command-line usage. +// printAll controls whether only basic flags or all flags are included. +// configs are expected to be pointers to structs. +func Usage(printAll bool, configs ...interface{}) error { + fields := map[uintptr]reflect.StructField{} + for _, c := range configs { + if err := parseStructure(c, fields); err != nil { + return err + } + } + + fs := flag.CommandLine + fmt.Fprintf(fs.Output(), "Usage of %s:\n", os.Args[0]) + fs.VisitAll(func(fl *flag.Flag) { + v := reflect.ValueOf(fl.Value) + fieldCat := fieldcategory.Basic + var field reflect.StructField + + // Do not print usage for deprecated flags. + if fl.Value.String() == "deprecated" { + return + } + + if override, ok := fieldcategory.GetOverride(fl.Name); ok { + fieldCat = override + } else if v.Kind() == reflect.Ptr { + ptr := v.Pointer() + field, ok = fields[ptr] + if ok { + catStr := field.Tag.Get("category") + switch catStr { + case "advanced": + fieldCat = fieldcategory.Advanced + case "experimental": + fieldCat = fieldcategory.Experimental + } + } + } + + if fieldCat != fieldcategory.Basic && !printAll { + // Don't print help for this flag since we're supposed to print only basic flags + return + } + + var b strings.Builder + // Two spaces before -; see next two comments. + fmt.Fprintf(&b, " -%s", fl.Name) + name := getFlagName(fl) + if len(name) > 0 { + b.WriteString(" ") + b.WriteString(strings.ReplaceAll(name, " ", "-")) + } + // Four spaces before the tab triggers good alignment + // for both 4- and 8-space tab stops. + b.WriteString("\n \t") + if fieldCat == fieldcategory.Experimental { + b.WriteString("[experimental] ") + } + b.WriteString(strings.ReplaceAll(fl.Usage, "\n", "\n \t")) + + if defValue := getFlagDefault(fl, field); !isZeroValue(fl, defValue) { + v := reflect.ValueOf(fl.Value) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.String { + // put quotes on the value + fmt.Fprintf(&b, " (default %q)", defValue) + } else { + fmt.Fprintf(&b, " (default %v)", defValue) + } + } + fmt.Fprint(fs.Output(), b.String(), "\n") + }) + + if !printAll { + fmt.Fprintf(fs.Output(), "\nTo see all flags, use -help-all\n") + } + + return nil +} + +// isZeroValue determines whether the string represents the zero +// value for a flag. +func isZeroValue(fl *flag.Flag, value string) bool { + // Build a zero value of the flag's Value type, and see if the + // result of calling its String method equals the value passed in. + // This works unless the Value type is itself an interface type. + typ := reflect.TypeOf(fl.Value) + var z reflect.Value + if typ.Kind() == reflect.Ptr { + z = reflect.New(typ.Elem()) + } else { + z = reflect.Zero(typ) + } + return value == z.Interface().(flag.Value).String() +} + +// parseStructure parses a struct and populates fields. +func parseStructure(structure interface{}, fields map[uintptr]reflect.StructField) error { + // structure is expected to be a pointer to a struct + if reflect.TypeOf(structure).Kind() != reflect.Ptr { + t := reflect.TypeOf(structure) + return fmt.Errorf("%s is a %s while a %s is expected", t, t.Kind(), reflect.Ptr) + } + v := reflect.ValueOf(structure).Elem() + if v.Kind() != reflect.Struct { + return fmt.Errorf("%s is a %s while a %s is expected", v, v.Kind(), reflect.Struct) + } + + t := v.Type() + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type.Kind() == reflect.Func { + continue + } + + fieldValue := v.FieldByIndex(field.Index) + + // Take address of field value and map it to field + fields[fieldValue.Addr().Pointer()] = field + + // Recurse if a struct + if field.Type.Kind() != reflect.Struct || ignoreStructType(field.Type) || !field.IsExported() { + continue + } + + if err := parseStructure(fieldValue.Addr().Interface(), fields); err != nil { + return err + } + } + + return nil +} + +// Descending into some structs breaks check for "advanced" category for some fields (eg. flagext.Secret), +// because field itself is at the same memory address as the internal field in the struct, and advanced-category-check +// then gets confused. +var ignoredStructTypes = []reflect.Type{ + reflect.TypeOf(flagext.Secret{}), +} + +func ignoreStructType(fieldType reflect.Type) bool { + for _, t := range ignoredStructTypes { + if fieldType == t { + return true + } + } + return false +} + +func getFlagName(fl *flag.Flag) string { + if getter, ok := fl.Value.(flag.Getter); ok { + if v := reflect.ValueOf(getter.Get()); v.IsValid() { + t := v.Type() + switch t.Name() { + case "bool": + return "" + case "Duration": + return "duration" + case "float64": + return "float" + case "int", "int64": + return "int" + case "string": + return "string" + case "uint", "uint64": + return "uint" + case "Secret": + return "string" + default: + return "value" + } + } + } + + // Check custom types. + if v := reflect.ValueOf(fl.Value); v.IsValid() { + switch v.Type().String() { + case "*flagext.Secret": + return "string" + case "*flagext.StringSlice": + return "string" + case "*flagext.StringSliceCSV": + return "comma-separated list of strings" + case "*flagext.CIDRSliceCSV": + return "comma-separated list of strings" + case "*flagext.URLValue": + return "string" + case "*url.URL": + return "string" + case "*model.Duration": + return "duration" + case "*tsdb.DurationList": + return "comma-separated list of durations" + } + } + + return "value" +} + +func getFlagDefault(fl *flag.Flag, field reflect.StructField) string { + if docDefault := parseDocTag(field)["default"]; docDefault != "" { + return docDefault + } + return fl.DefValue +} + +func parseDocTag(f reflect.StructField) map[string]string { + cfg := map[string]string{} + tag := f.Tag.Get("doc") + + if tag == "" { + return cfg + } + + for _, entry := range strings.Split(tag, "|") { + parts := strings.SplitN(entry, "=", 2) + + switch len(parts) { + case 1: + cfg[parts[0]] = "" + case 2: + cfg[parts[0]] = parts[1] + } + } + + return cfg +} diff --git a/pkg/util/fieldcategory/overrides.go b/pkg/util/fieldcategory/overrides.go new file mode 100644 index 000000000..bdaee77c7 --- /dev/null +++ b/pkg/util/fieldcategory/overrides.go @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: AGPL-3.0-only + +package fieldcategory + +import "fmt" + +type Category int + +const ( + // Basic is the basic field category, and the default if none is defined. + Basic Category = iota + // Advanced is the advanced field category. + Advanced + // Experimental is the experimental field category. + Experimental +) + +func (c Category) String() string { + switch c { + case Basic: + return "basic" + case Advanced: + return "advanced" + case Experimental: + return "experimental" + default: + panic(fmt.Sprintf("Unknown field category: %d", c)) + } +} + +// Fields are primarily categorized via struct tags, but this can be impossible when third party libraries are involved +// Only categorize fields here when you can't otherwise, since struct tags are less likely to become stale +var overrides = map[string]Category{} + +func AddOverrides(o map[string]Category) { + for n, c := range o { + overrides[n] = c + } +} + +func GetOverride(fieldName string) (category Category, ok bool) { + category, ok = overrides[fieldName] + return +} + +func VisitOverrides(f func(name string)) { + for override := range overrides { + f(override) + } +}