diff --git a/cmd/skywire-visor/commands/root.go b/cmd/skywire-visor/commands/root.go index 9eb51f5e75..ed47be7292 100644 --- a/cmd/skywire-visor/commands/root.go +++ b/cmd/skywire-visor/commands/root.go @@ -78,12 +78,7 @@ func init() { rootCmd.Flags().StringVarP(&cfg.port, "port", "", "6060", "port for http-mode of pprof") rootCmd.Flags().StringVarP(&cfg.startDelay, "delay", "", "0ns", "delay before visor start") - restartCtx, err := restart.CaptureContext() - if err != nil { - log.Printf("Failed to capture context: %v", err) - } else { - cfg.restartCtx = restartCtx - } + cfg.restartCtx = restart.CaptureContext() } // Execute executes root CLI command. diff --git a/pkg/restart/restart.go b/pkg/restart/restart.go index b26016c6c1..1313097ec8 100644 --- a/pkg/restart/restart.go +++ b/pkg/restart/restart.go @@ -5,7 +5,6 @@ import ( "log" "os" "os/exec" - "path/filepath" "sync/atomic" "time" @@ -13,44 +12,42 @@ import ( ) var ( - // ErrMalformedArgs is returned when executable args are malformed. - ErrMalformedArgs = errors.New("malformed args") - // ErrAlreadyRestarting is returned on restarting attempt when restarting is in progress. - ErrAlreadyRestarting = errors.New("already restarting") + // ErrAlreadyStarting is returned on starting attempt when starting is in progress. + ErrAlreadyStarting = errors.New("already starting") ) -// DefaultCheckDelay is a default delay for checking if a new instance is started successfully. -const DefaultCheckDelay = 1 * time.Second +const ( + // DefaultCheckDelay is a default delay for checking if a new instance is started successfully. + DefaultCheckDelay = 1 * time.Second + extraWaitingTime = 1 * time.Second + delayArgName = "--delay" +) // Context describes data required for restarting visor. type Context struct { - log logrus.FieldLogger - checkDelay time.Duration - workingDirectory string - args []string - isRestarting int32 - appendDelay bool // disabled in tests + log logrus.FieldLogger + isStarting int32 + checkDelay time.Duration + appendDelay bool // disabled in tests + cmd *exec.Cmd } // CaptureContext captures data required for restarting visor. // Data used by CaptureContext must not be modified before, // therefore calling CaptureContext immediately after starting executable is recommended. -func CaptureContext() (*Context, error) { - wd, err := os.Getwd() - if err != nil { - return nil, err +func CaptureContext() *Context { + cmd := exec.Command(os.Args[0], os.Args[1:]...) + + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + cmd.Stderr = os.Stderr + cmd.Env = os.Environ() + + return &Context{ + cmd: cmd, + checkDelay: DefaultCheckDelay, + appendDelay: true, } - - args := os.Args - - context := &Context{ - checkDelay: DefaultCheckDelay, - workingDirectory: wd, - args: args, - appendDelay: true, - } - - return context, nil } // RegisterLogger registers a logger instead of standard one. @@ -69,22 +66,13 @@ func (c *Context) SetCheckDelay(delay time.Duration) { // Start starts a new executable using Context. func (c *Context) Start() error { - if !atomic.CompareAndSwapInt32(&c.isRestarting, 0, 1) { - return ErrAlreadyRestarting + if !atomic.CompareAndSwapInt32(&c.isStarting, 0, 1) { + return ErrAlreadyStarting } - defer atomic.StoreInt32(&c.isRestarting, 0) + defer atomic.StoreInt32(&c.isStarting, 0) - if len(c.args) == 0 { - return ErrMalformedArgs - } - - execPath := c.args[0] - if !filepath.IsAbs(execPath) { - execPath = filepath.Join(c.workingDirectory, execPath) - } - - errCh := c.startExec(execPath) + errCh := c.startExec() ticker := time.NewTicker(c.checkDelay) defer ticker.Stop() @@ -94,58 +82,37 @@ func (c *Context) Start() error { c.errorLogger()("Failed to start new instance: %v", err) return err case <-ticker.C: - c.infoLogger()("New instance started successfully, exiting") + c.infoLogger()("New instance started successfully, exiting from the old one") return nil } } -func (c *Context) startExec(path string) chan error { +func (c *Context) startExec() chan error { errCh := make(chan error, 1) - go func(path string) { + go func() { defer close(errCh) - normalizedPath, err := exec.LookPath(path) - if err != nil { - errCh <- err - return - } - - if len(c.args) == 0 { - errCh <- ErrMalformedArgs - return - } + c.adjustArgs() - args := c.startArgs() - cmd := exec.Command(normalizedPath, args...) // nolint:gosec + c.infoLogger()("Starting new instance of executable (args: %q)", c.cmd.Args) - cmd.Stdout = os.Stdout - cmd.Stdin = os.Stdin - cmd.Stderr = os.Stderr - cmd.Env = os.Environ() - - c.infoLogger()("Starting new instance of executable (path: %q, args: %q)", path, args) - - if err := cmd.Start(); err != nil { + if err := c.cmd.Start(); err != nil { errCh <- err return } - if err := cmd.Wait(); err != nil { + if err := c.cmd.Wait(); err != nil { errCh <- err return } - }(path) + }() return errCh } -const extraWaitingTime = 1 * time.Second - -func (c *Context) startArgs() []string { - args := c.args[1:] - - const delayArgName = "--delay" +func (c *Context) adjustArgs() { + args := c.cmd.Args l := len(args) for i := 0; i < l; i++ { @@ -161,7 +128,7 @@ func (c *Context) startArgs() []string { args = append(args, delayArgName, delay.String()) } - return args + c.cmd.Args = args } func (c *Context) infoLogger() func(string, ...interface{}) { diff --git a/pkg/restart/restart_test.go b/pkg/restart/restart_test.go index 29cbb6f962..e75c880795 100644 --- a/pkg/restart/restart_test.go +++ b/pkg/restart/restart_test.go @@ -2,6 +2,7 @@ package restart import ( "os" + "os/exec" "testing" "time" @@ -11,21 +12,19 @@ import ( ) func TestCaptureContext(t *testing.T) { - cc, err := CaptureContext() - require.NoError(t, err) + cc := CaptureContext() - wd, err := os.Getwd() - assert.NoError(t, err) - - require.Equal(t, wd, cc.workingDirectory) require.Equal(t, DefaultCheckDelay, cc.checkDelay) - require.Equal(t, os.Args, cc.args) + require.Equal(t, os.Args, cc.cmd.Args) + require.Equal(t, os.Stdout, cc.cmd.Stdout) + require.Equal(t, os.Stdin, cc.cmd.Stdin) + require.Equal(t, os.Stderr, cc.cmd.Stderr) + require.Equal(t, os.Environ(), cc.cmd.Env) require.Nil(t, cc.log) } func TestContext_RegisterLogger(t *testing.T) { - cc, err := CaptureContext() - require.NoError(t, err) + cc := CaptureContext() require.Nil(t, cc.log) logger := logging.MustGetLogger("test") @@ -34,17 +33,13 @@ func TestContext_RegisterLogger(t *testing.T) { } func TestContext_Start(t *testing.T) { - cc, err := CaptureContext() - require.NoError(t, err) - assert.NotZero(t, len(cc.args)) - - cc.workingDirectory = "" + cc := CaptureContext() + assert.NotZero(t, len(cc.cmd.Args)) t.Run("executable started", func(t *testing.T) { cmd := "touch" path := "/tmp/test_restart" - args := []string{cmd, path} - cc.args = args + cc.cmd = exec.Command(cmd, path) cc.appendDelay = false assert.NoError(t, cc.Start()) @@ -53,26 +48,16 @@ func TestContext_Start(t *testing.T) { t.Run("bad args", func(t *testing.T) { cmd := "bad_command" - args := []string{cmd} - cc.args = args + cc.cmd = exec.Command(cmd) // TODO(nkryuchkov): Check if it works on Linux and Windows, if not then change the error text. assert.EqualError(t, cc.Start(), `exec: "bad_command": executable file not found in $PATH`) }) - t.Run("empty args", func(t *testing.T) { - cc.args = nil - - assert.Equal(t, ErrMalformedArgs, cc.Start()) - }) - t.Run("already restarting", func(t *testing.T) { - cc.args = nil - cmd := "touch" path := "/tmp/test_restart" - args := []string{cmd, path} - cc.args = args + cc.cmd = exec.Command(cmd, path) cc.appendDelay = false ch := make(chan error, 1) @@ -81,15 +66,14 @@ func TestContext_Start(t *testing.T) { }() assert.NoError(t, cc.Start()) - assert.NoError(t, os.Remove(path)) + assert.Equal(t, ErrAlreadyStarting, <-ch) - assert.Equal(t, ErrAlreadyRestarting, <-ch) + assert.NoError(t, os.Remove(path)) }) } func TestContext_SetCheckDelay(t *testing.T) { - cc, err := CaptureContext() - require.NoError(t, err) + cc := CaptureContext() require.Equal(t, DefaultCheckDelay, cc.checkDelay) const oneSecond = 1 * time.Second