Skip to content

Commit

Permalink
Simplify restart logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nkryuchkov committed Dec 26, 2019
1 parent d017376 commit a754316
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 111 deletions.
7 changes: 1 addition & 6 deletions cmd/skywire-visor/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,7 @@ func init() {
rootCmd.Flags().StringVarP(&cfg.port, "port", "", "6060", "port for http-mode of pprof")
rootCmd.Flags().StringVarP(&cfg.startDelay, "delay", "", "0ns", "delay before visor start")

restartCtx, err := restart.CaptureContext()
if err != nil {
log.Printf("Failed to capture context: %v", err)
} else {
cfg.restartCtx = restartCtx
}
cfg.restartCtx = restart.CaptureContext()
}

// Execute executes root CLI command.
Expand Down
113 changes: 40 additions & 73 deletions pkg/restart/restart.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,49 @@ import (
"log"
"os"
"os/exec"
"path/filepath"
"sync/atomic"
"time"

"github.com/sirupsen/logrus"
)

var (
// ErrMalformedArgs is returned when executable args are malformed.
ErrMalformedArgs = errors.New("malformed args")
// ErrAlreadyRestarting is returned on restarting attempt when restarting is in progress.
ErrAlreadyRestarting = errors.New("already restarting")
// ErrAlreadyStarting is returned on starting attempt when starting is in progress.
ErrAlreadyStarting = errors.New("already starting")
)

// DefaultCheckDelay is a default delay for checking if a new instance is started successfully.
const DefaultCheckDelay = 1 * time.Second
const (
// DefaultCheckDelay is a default delay for checking if a new instance is started successfully.
DefaultCheckDelay = 1 * time.Second
extraWaitingTime = 1 * time.Second
delayArgName = "--delay"
)

// Context describes data required for restarting visor.
type Context struct {
log logrus.FieldLogger
checkDelay time.Duration
workingDirectory string
args []string
isRestarting int32
appendDelay bool // disabled in tests
log logrus.FieldLogger
isStarting int32
checkDelay time.Duration
appendDelay bool // disabled in tests
cmd *exec.Cmd
}

// CaptureContext captures data required for restarting visor.
// Data used by CaptureContext must not be modified before,
// therefore calling CaptureContext immediately after starting executable is recommended.
func CaptureContext() (*Context, error) {
wd, err := os.Getwd()
if err != nil {
return nil, err
func CaptureContext() *Context {
cmd := exec.Command(os.Args[0], os.Args[1:]...)

cmd.Stdout = os.Stdout
cmd.Stdin = os.Stdin
cmd.Stderr = os.Stderr
cmd.Env = os.Environ()

return &Context{
cmd: cmd,
checkDelay: DefaultCheckDelay,
appendDelay: true,
}

args := os.Args

context := &Context{
checkDelay: DefaultCheckDelay,
workingDirectory: wd,
args: args,
appendDelay: true,
}

return context, nil
}

// RegisterLogger registers a logger instead of standard one.
Expand All @@ -69,22 +66,13 @@ func (c *Context) SetCheckDelay(delay time.Duration) {

// Start starts a new executable using Context.
func (c *Context) Start() error {
if !atomic.CompareAndSwapInt32(&c.isRestarting, 0, 1) {
return ErrAlreadyRestarting
if !atomic.CompareAndSwapInt32(&c.isStarting, 0, 1) {
return ErrAlreadyStarting
}

defer atomic.StoreInt32(&c.isRestarting, 0)
defer atomic.StoreInt32(&c.isStarting, 0)

if len(c.args) == 0 {
return ErrMalformedArgs
}

execPath := c.args[0]
if !filepath.IsAbs(execPath) {
execPath = filepath.Join(c.workingDirectory, execPath)
}

errCh := c.startExec(execPath)
errCh := c.startExec()

ticker := time.NewTicker(c.checkDelay)
defer ticker.Stop()
Expand All @@ -94,58 +82,37 @@ func (c *Context) Start() error {
c.errorLogger()("Failed to start new instance: %v", err)
return err
case <-ticker.C:
c.infoLogger()("New instance started successfully, exiting")
c.infoLogger()("New instance started successfully, exiting from the old one")
return nil
}
}

func (c *Context) startExec(path string) chan error {
func (c *Context) startExec() chan error {
errCh := make(chan error, 1)

go func(path string) {
go func() {
defer close(errCh)

normalizedPath, err := exec.LookPath(path)
if err != nil {
errCh <- err
return
}

if len(c.args) == 0 {
errCh <- ErrMalformedArgs
return
}
c.adjustArgs()

args := c.startArgs()
cmd := exec.Command(normalizedPath, args...) // nolint:gosec
c.infoLogger()("Starting new instance of executable (args: %q)", c.cmd.Args)

cmd.Stdout = os.Stdout
cmd.Stdin = os.Stdin
cmd.Stderr = os.Stderr
cmd.Env = os.Environ()

c.infoLogger()("Starting new instance of executable (path: %q, args: %q)", path, args)

if err := cmd.Start(); err != nil {
if err := c.cmd.Start(); err != nil {
errCh <- err
return
}

if err := cmd.Wait(); err != nil {
if err := c.cmd.Wait(); err != nil {
errCh <- err
return
}
}(path)
}()

return errCh
}

const extraWaitingTime = 1 * time.Second

func (c *Context) startArgs() []string {
args := c.args[1:]

const delayArgName = "--delay"
func (c *Context) adjustArgs() {
args := c.cmd.Args

l := len(args)
for i := 0; i < l; i++ {
Expand All @@ -161,7 +128,7 @@ func (c *Context) startArgs() []string {
args = append(args, delayArgName, delay.String())
}

return args
c.cmd.Args = args
}

func (c *Context) infoLogger() func(string, ...interface{}) {
Expand Down
48 changes: 16 additions & 32 deletions pkg/restart/restart_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package restart

import (
"os"
"os/exec"
"testing"
"time"

Expand All @@ -11,21 +12,19 @@ import (
)

func TestCaptureContext(t *testing.T) {
cc, err := CaptureContext()
require.NoError(t, err)
cc := CaptureContext()

wd, err := os.Getwd()
assert.NoError(t, err)

require.Equal(t, wd, cc.workingDirectory)
require.Equal(t, DefaultCheckDelay, cc.checkDelay)
require.Equal(t, os.Args, cc.args)
require.Equal(t, os.Args, cc.cmd.Args)
require.Equal(t, os.Stdout, cc.cmd.Stdout)
require.Equal(t, os.Stdin, cc.cmd.Stdin)
require.Equal(t, os.Stderr, cc.cmd.Stderr)
require.Equal(t, os.Environ(), cc.cmd.Env)
require.Nil(t, cc.log)
}

func TestContext_RegisterLogger(t *testing.T) {
cc, err := CaptureContext()
require.NoError(t, err)
cc := CaptureContext()
require.Nil(t, cc.log)

logger := logging.MustGetLogger("test")
Expand All @@ -34,17 +33,13 @@ func TestContext_RegisterLogger(t *testing.T) {
}

func TestContext_Start(t *testing.T) {
cc, err := CaptureContext()
require.NoError(t, err)
assert.NotZero(t, len(cc.args))

cc.workingDirectory = ""
cc := CaptureContext()
assert.NotZero(t, len(cc.cmd.Args))

t.Run("executable started", func(t *testing.T) {
cmd := "touch"
path := "/tmp/test_restart"
args := []string{cmd, path}
cc.args = args
cc.cmd = exec.Command(cmd, path)
cc.appendDelay = false

assert.NoError(t, cc.Start())
Expand All @@ -53,26 +48,16 @@ func TestContext_Start(t *testing.T) {

t.Run("bad args", func(t *testing.T) {
cmd := "bad_command"
args := []string{cmd}
cc.args = args
cc.cmd = exec.Command(cmd)

// TODO(nkryuchkov): Check if it works on Linux and Windows, if not then change the error text.
assert.EqualError(t, cc.Start(), `exec: "bad_command": executable file not found in $PATH`)
})

t.Run("empty args", func(t *testing.T) {
cc.args = nil

assert.Equal(t, ErrMalformedArgs, cc.Start())
})

t.Run("already restarting", func(t *testing.T) {
cc.args = nil

cmd := "touch"
path := "/tmp/test_restart"
args := []string{cmd, path}
cc.args = args
cc.cmd = exec.Command(cmd, path)
cc.appendDelay = false

ch := make(chan error, 1)
Expand All @@ -81,15 +66,14 @@ func TestContext_Start(t *testing.T) {
}()

assert.NoError(t, cc.Start())
assert.NoError(t, os.Remove(path))
assert.Equal(t, ErrAlreadyStarting, <-ch)

assert.Equal(t, ErrAlreadyRestarting, <-ch)
assert.NoError(t, os.Remove(path))
})
}

func TestContext_SetCheckDelay(t *testing.T) {
cc, err := CaptureContext()
require.NoError(t, err)
cc := CaptureContext()
require.Equal(t, DefaultCheckDelay, cc.checkDelay)

const oneSecond = 1 * time.Second
Expand Down

0 comments on commit a754316

Please sign in to comment.