From f183c5b76dd265a1669b406531be4868a935ffd9 Mon Sep 17 00:00:00 2001 From: Francesco Cheinasso Date: Mon, 24 Jun 2024 11:26:44 +0200 Subject: [PATCH] Flags refactoring + custom MTU Signed-off-by: Francesco Cheinasso Signed-off-by: Francesco Cheinasso --- flags/flags.go | 39 ++++++++++++++++++++++++++++++++++++ flags/options.go | 13 ++++++++++++ go.mod | 1 + go.sum | 2 ++ main.go | 51 ++++++++++++++---------------------------------- 5 files changed, 70 insertions(+), 36 deletions(-) create mode 100644 flags/flags.go create mode 100644 flags/options.go diff --git a/flags/flags.go b/flags/flags.go new file mode 100644 index 000000000..c1c638979 --- /dev/null +++ b/flags/flags.go @@ -0,0 +1,39 @@ +package flags + +import ( + "fmt" + "os" + + "github.com/spf13/pflag" + "golang.zx2c4.com/wireguard/device" +) + +func Parse(opts *Options) error { + pflag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [flags] \n", os.Args[0]) + pflag.PrintDefaults() + } + + pflag.IntVar(&opts.MTU, "mtu", device.DefaultMTU, "Set the MTU of the device") + pflag.BoolVar(&opts.Foreground, "foreground", false, "Remain in the foreground") + pflag.BoolVarP(&opts.ShowVersion, "version", "v", false, "Print the version number and exit") + + pflag.Parse() + + if opts.ShowVersion { + return nil + } + + if err := setInterfaceName(opts); err != nil { + return err + } + return nil +} + +func setInterfaceName(opts *Options) error { + if pflag.NArg() != 1 { + return fmt.Errorf("Must pass exactly one interface name, but got %d", pflag.NArg()) + } + opts.InterfaceName = pflag.Arg(0) + return nil +} diff --git a/flags/options.go b/flags/options.go new file mode 100644 index 000000000..aa0a4e0df --- /dev/null +++ b/flags/options.go @@ -0,0 +1,13 @@ +package flags + +type Options struct { + InterfaceName string + + MTU int + Foreground bool + ShowVersion bool +} + +func NewOptions() *Options { + return &Options{} +} diff --git a/go.mod b/go.mod index 919dc4927..d7572e243 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module golang.zx2c4.com/wireguard go 1.20 require ( + github.com/spf13/pflag v1.0.5 golang.org/x/crypto v0.13.0 golang.org/x/net v0.15.0 golang.org/x/sys v0.12.0 diff --git a/go.sum b/go.sum index 6bcecea3f..28d7f9d25 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= diff --git a/main.go b/main.go index e01611694..adf84e6bd 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ import ( "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/flags" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" ) @@ -32,10 +33,6 @@ const ( ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" ) -func printUsage() { - fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) -} - func warning() { switch runtime.GOOS { case "linux", "freebsd", "openbsd": @@ -58,41 +55,21 @@ func warning() { } func main() { - if len(os.Args) == 2 && os.Args[1] == "--version" { - fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld .\n", Version, runtime.GOOS, runtime.GOARCH) - return + opts := flags.NewOptions() + if err := flags.Parse(opts); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(ExitSetupFailed) } - warning() - - var foreground bool - var interfaceName string - if len(os.Args) < 2 || len(os.Args) > 3 { - printUsage() + if opts.ShowVersion { + fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld .\n", Version, runtime.GOOS, runtime.GOARCH) return } - switch os.Args[1] { - - case "-f", "--foreground": - foreground = true - if len(os.Args) != 3 { - printUsage() - return - } - interfaceName = os.Args[2] - - default: - foreground = false - if len(os.Args) != 2 { - printUsage() - return - } - interfaceName = os.Args[1] - } + warning() - if !foreground { - foreground = os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" + if !opts.Foreground { + opts.Foreground = os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" } // get log level (default: info) @@ -111,10 +88,12 @@ func main() { // open TUN device (or use supplied fd) + interfaceName := opts.InterfaceName + tdev, err := func() (tun.Device, error) { tunFdStr := os.Getenv(ENV_WG_TUN_FD) if tunFdStr == "" { - return tun.CreateTUN(interfaceName, device.DefaultMTU) + return tun.CreateTUN(interfaceName, opts.MTU) } // construct tun device from supplied fd @@ -130,7 +109,7 @@ func main() { } file := os.NewFile(uintptr(fd), "") - return tun.CreateTUNFromFile(file, device.DefaultMTU) + return tun.CreateTUNFromFile(file, opts.MTU) }() if err == nil { @@ -176,7 +155,7 @@ func main() { } // daemonize the process - if !foreground { + if !opts.Foreground { env := os.Environ() env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))