diff --git a/pkg/restart/restart.go b/pkg/restart/restart.go index 0c485ebb8d..94a61922e2 100644 --- a/pkg/restart/restart.go +++ b/pkg/restart/restart.go @@ -12,8 +12,8 @@ import ( ) var ( - // ErrAlreadyStarting is returned on starting attempt when starting is in progress. - ErrAlreadyStarting = errors.New("already starting") + // ErrAlreadyStarted is returned when Start is already called. + ErrAlreadyStarted = errors.New("already started") ) const ( @@ -28,7 +28,7 @@ type Context struct { log logrus.FieldLogger cmd *exec.Cmd checkDelay time.Duration - isStarting int32 + isStarted int32 appendDelay bool // disabled in tests } @@ -66,12 +66,10 @@ func (c *Context) SetCheckDelay(delay time.Duration) { // Start starts a new executable using Context. func (c *Context) Start() error { - if !atomic.CompareAndSwapInt32(&c.isStarting, 0, 1) { - return ErrAlreadyStarting + if !atomic.CompareAndSwapInt32(&c.isStarted, 0, 1) { + return ErrAlreadyStarted } - defer atomic.StoreInt32(&c.isStarting, 0) - errCh := c.startExec() ticker := time.NewTicker(c.checkDelay) diff --git a/pkg/restart/restart_test.go b/pkg/restart/restart_test.go index c29defd5fe..3442bc5eb4 100644 --- a/pkg/restart/restart_test.go +++ b/pkg/restart/restart_test.go @@ -33,10 +33,10 @@ func TestContext_RegisterLogger(t *testing.T) { } func TestContext_Start(t *testing.T) { - cc := CaptureContext() - assert.NotZero(t, len(cc.cmd.Args)) - t.Run("executable started", func(t *testing.T) { + cc := CaptureContext() + assert.NotZero(t, len(cc.cmd.Args)) + cmd := "touch" path := "/tmp/test_start" cc.cmd = exec.Command(cmd, path) // nolint:gosec @@ -47,6 +47,9 @@ func TestContext_Start(t *testing.T) { }) t.Run("bad args", func(t *testing.T) { + cc := CaptureContext() + assert.NotZero(t, len(cc.cmd.Args)) + cmd := "bad_command" cc.cmd = exec.Command(cmd) // nolint:gosec @@ -60,6 +63,9 @@ func TestContext_Start(t *testing.T) { }) t.Run("already starting", func(t *testing.T) { + cc := CaptureContext() + assert.NotZero(t, len(cc.cmd.Args)) + cmd := "touch" path := "/tmp/test_start" cc.cmd = exec.Command(cmd, path) // nolint:gosec @@ -74,7 +80,7 @@ func TestContext_Start(t *testing.T) { err2 := <-errCh errors := []error{err1, err2} - assert.Contains(t, errors, ErrAlreadyStarting) + assert.Contains(t, errors, ErrAlreadyStarted) assert.Contains(t, errors, nil) assert.NoError(t, os.Remove(path))