Skip to content

Commit

Permalink
Merge pull request #80 from nkryuchkov/feature/restart-visor
Browse files Browse the repository at this point in the history
Implement visor restart from hypervisor
  • Loading branch information
Darkren authored Dec 27, 2019
2 parents 6e30191 + df0797f commit 137c9c3
Show file tree
Hide file tree
Showing 8 changed files with 330 additions and 7 deletions.
22 changes: 21 additions & 1 deletion cmd/skywire-visor/commands/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/spf13/cobra"

"github.com/SkycoinProject/skywire-mainnet/internal/utclient"
"github.com/SkycoinProject/skywire-mainnet/pkg/restart"
"github.com/SkycoinProject/skywire-mainnet/pkg/util/pathutil"
"github.com/SkycoinProject/skywire-mainnet/pkg/visor"
)
Expand All @@ -39,13 +40,15 @@ type runCfg struct {
cfgFromStdin bool
profileMode string
port string
startDelay string
args []string

profileStop func()
logger *logging.Logger
masterLogger *logging.MasterLogger
conf visor.Config
node *visor.Node
restartCtx *restart.Context
}

var cfg *runCfg
Expand Down Expand Up @@ -73,6 +76,9 @@ func init() {
rootCmd.Flags().BoolVarP(&cfg.cfgFromStdin, "stdin", "i", false, "read config from STDIN")
rootCmd.Flags().StringVarP(&cfg.profileMode, "profile", "p", "none", "enable profiling with pprof. Mode: none or one of: [cpu, mem, mutex, block, trace, http]")
rootCmd.Flags().StringVarP(&cfg.port, "port", "", "6060", "port for http-mode of pprof")
rootCmd.Flags().StringVarP(&cfg.startDelay, "delay", "", "0ns", "delay before visor start")

cfg.restartCtx = restart.CaptureContext()
}

// Execute executes root CLI command.
Expand Down Expand Up @@ -148,7 +154,19 @@ func (cfg *runCfg) readConfig() *runCfg {
}

func (cfg *runCfg) runNode() *runCfg {
node, err := visor.NewNode(&cfg.conf, cfg.masterLogger)
startDelay, err := time.ParseDuration(cfg.startDelay)
if err != nil {
cfg.logger.Warnf("Using no visor start delay due to parsing failure: %v", err)
startDelay = time.Duration(0)
}

if startDelay != 0 {
cfg.logger.Infof("Visor start delay is %v, waiting...", startDelay)
}

time.Sleep(startDelay)

node, err := visor.NewNode(&cfg.conf, cfg.masterLogger, cfg.restartCtx)
if err != nil {
cfg.logger.Fatal("Failed to initialize node: ", err)
}
Expand Down Expand Up @@ -181,7 +199,9 @@ func (cfg *runCfg) runNode() *runCfg {
if cfg.conf.ShutdownTimeout == 0 {
cfg.conf.ShutdownTimeout = defaultShutdownTimeout
}

cfg.node = node

return cfg
}

Expand Down
13 changes: 13 additions & 0 deletions pkg/hypervisor/hypervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ func (m *Node) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.Put("/nodes/{pk}/routes/{rid}", m.putRoute())
r.Delete("/nodes/{pk}/routes/{rid}", m.deleteRoute())
r.Get("/nodes/{pk}/loops", m.getLoops())
r.Get("/nodes/{pk}/restart", m.restart())
})
})
r.ServeHTTP(w, req)
Expand Down Expand Up @@ -569,6 +570,18 @@ func (m *Node) getLoops() http.HandlerFunc {
})
}

// restart returns a handler that invokes the Restart RPC on the visor
// resolved for the request and reports the outcome as JSON.
//
// NOTE: The reply is delayed, because the restart procedure first checks
// whether the new executable has started successfully.
func (m *Node) restart() http.HandlerFunc {
	handle := func(w http.ResponseWriter, r *http.Request, ctx *httpCtx) {
		err := ctx.RPC.Restart()
		if err == nil {
			httputil.WriteJSON(w, r, http.StatusOK, true)
			return
		}

		httputil.WriteJSON(w, r, http.StatusInternalServerError, err)
	}

	return m.withCtx(m.nodeCtx, handle)
}

/*
<<< Helper functions >>>
*/
Expand Down
155 changes: 155 additions & 0 deletions pkg/restart/restart.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package restart

import (
	"errors"
	"log"
	"os"
	"os/exec"
	"strings"
	"sync/atomic"
	"time"

	"github.com/sirupsen/logrus"
)

var (
// ErrAlreadyStarting is returned by Context.Start when a previous Start call
// on the same Context is still in progress.
ErrAlreadyStarting = errors.New("already starting")
)

const (
// DefaultCheckDelay is the default time Start waits for the new instance to
// report a failure before the start is considered successful.
DefaultCheckDelay = 1 * time.Second
// extraWaitingTime is added to the check delay to form the start delay that
// is passed to the new instance.
extraWaitingTime = 1 * time.Second
// delayArgName is the command-line flag that adjustArgs strips from, and
// optionally appends to, the arguments of the new instance.
delayArgName = "--delay"
)

// Context describes data required for restarting visor.
// Instances are created with CaptureContext.
type Context struct {
// log is an optional logger set via RegisterLogger; when nil, messages go
// to a standard library logger writing to stdout.
log logrus.FieldLogger
// cmd is the command used to start the new instance.
cmd *exec.Cmd
// checkDelay is how long Start waits for an early failure of the new
// instance before reporting success.
checkDelay time.Duration
// isStarting is 1 while Start is in progress and 0 otherwise; it is
// accessed atomically.
isStarting int32
// appendDelay controls whether adjustArgs appends the delay flag to the
// command arguments.
appendDelay bool // disabled in tests
}

// CaptureContext snapshots everything needed to start the current executable
// again: its command line, environment and standard streams.
// The captured values (os.Args, the environment) must still be unmodified at
// the time of the call, so it is best called right after the process starts.
func CaptureContext() *Context {
	captured := &Context{
		checkDelay:  DefaultCheckDelay,
		appendDelay: true,
	}

	captured.cmd = exec.Command(os.Args[0], os.Args[1:]...) // nolint:gosec
	captured.cmd.Env = os.Environ()
	captured.cmd.Stdin, captured.cmd.Stdout, captured.cmd.Stderr = os.Stdin, os.Stdout, os.Stderr

	return captured
}

// RegisterLogger makes the Context log through logger instead of the
// standard library fallback. Calling it on a nil Context is a no-op.
func (c *Context) RegisterLogger(logger logrus.FieldLogger) {
	if c == nil {
		return
	}

	c.log = logger
}

// SetCheckDelay overrides the default check delay used by Start.
// Calling it on a nil Context is a no-op.
func (c *Context) SetCheckDelay(delay time.Duration) {
	if c == nil {
		return
	}

	c.checkDelay = delay
}

// Start launches a new instance of the executable described by the Context
// and waits up to the check delay to see whether it fails early.
//
// It returns ErrAlreadyStarting if another Start call on the same Context is
// still in progress. If the new instance cannot be started, or exits with an
// error before the check delay elapses, that error is logged and returned.
// Otherwise Start returns nil.
func (c *Context) Start() error {
	if !atomic.CompareAndSwapInt32(&c.isStarting, 0, 1) {
		return ErrAlreadyStarting
	}

	defer atomic.StoreInt32(&c.isStarting, 0)

	errCh := c.startExec()

	// Only a single expiry is awaited, so a one-shot timer is sufficient.
	// Unlike time.NewTicker, time.NewTimer does not panic when the configured
	// check delay is zero or negative.
	timer := time.NewTimer(c.checkDelay)
	defer timer.Stop()

	select {
	case err, ok := <-errCh:
		if !ok || err == nil {
			// The channel is closed without a value when the new instance
			// exits cleanly before the check delay elapses; that is not a
			// start failure and must not be logged as one.
			c.infoLogger()("New instance exited without error before the check delay elapsed")
			return nil
		}

		c.errorLogger()("Failed to start new instance: %v", err)

		return err
	case <-timer.C:
		c.infoLogger()("New instance started successfully, exiting from the old one")
		return nil
	}
}

// startExec adjusts the command arguments and runs the command in a
// background goroutine. The returned channel receives the error from starting
// or waiting on the command, if any, and is closed once the command finishes.
func (c *Context) startExec() chan error {
	result := make(chan error, 1)

	// run starts the command and blocks until it exits.
	run := func() error {
		c.adjustArgs()

		c.infoLogger()("Starting new instance of executable (args: %q)", c.cmd.Args)

		if startErr := c.cmd.Start(); startErr != nil {
			return startErr
		}

		return c.cmd.Wait()
	}

	go func() {
		defer close(result)

		if runErr := run(); runErr != nil {
			result <- runErr
		}
	}()

	return result
}

// adjustArgs rewrites the command arguments for the new instance: every
// existing delay flag is removed, both in the "--delay <value>" and the
// "--delay=<value>" form, and, when appendDelay is set, a fresh delay flag of
// checkDelay plus extraWaitingTime is appended.
//
// A trailing delay flag with no following value is left untouched.
func (c *Context) adjustArgs() {
	const delayArgPrefix = delayArgName + "="

	current := c.cmd.Args

	// Filter into a fresh slice so the original backing array is not modified.
	kept := make([]string, 0, len(current)+2)

	for i := 0; i < len(current); i++ {
		arg := current[i]

		switch {
		case arg == delayArgName && i < len(current)-1:
			i++ // also skip the value that follows the flag
		case strings.HasPrefix(arg, delayArgPrefix):
			// flag and value are a single token; drop it
		default:
			kept = append(kept, arg)
		}
	}

	if c.appendDelay {
		delay := c.checkDelay + extraWaitingTime
		kept = append(kept, delayArgName, delay.String())
	}

	c.cmd.Args = kept
}

// infoLogger returns the printf-style function used for informational
// messages: the registered logger's Infof, or a standard library logger
// writing to stdout with an "[INFO] " prefix when no logger is registered.
func (c *Context) infoLogger() func(string, ...interface{}) {
	if c.log == nil {
		return log.New(os.Stdout, "[INFO] ", log.LstdFlags).Printf
	}

	return c.log.Infof
}

// errorLogger returns the printf-style function used for error messages:
// the registered logger's Errorf, or a standard library logger writing to
// stdout with an "[ERROR] " prefix when no logger is registered.
func (c *Context) errorLogger() func(string, ...interface{}) {
	if c.log == nil {
		return log.New(os.Stdout, "[ERROR] ", log.LstdFlags).Printf
	}

	return c.log.Errorf
}
83 changes: 83 additions & 0 deletions pkg/restart/restart_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package restart

import (
"os"
"os/exec"
"testing"
"time"

"github.com/SkycoinProject/skycoin/src/util/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestCaptureContext verifies that a freshly captured Context mirrors the
// current process: default check delay, os.Args, standard streams and
// environment, with no logger registered.
func TestCaptureContext(t *testing.T) {
	captured := CaptureContext()
	check := require.New(t)

	check.Equal(DefaultCheckDelay, captured.checkDelay)
	check.Equal(os.Args, captured.cmd.Args)
	check.Equal(os.Stdout, captured.cmd.Stdout)
	check.Equal(os.Stdin, captured.cmd.Stdin)
	check.Equal(os.Stderr, captured.cmd.Stderr)
	check.Equal(os.Environ(), captured.cmd.Env)
	check.Nil(captured.log)
}

func TestContext_RegisterLogger(t *testing.T) {
cc := CaptureContext()
require.Nil(t, cc.log)

logger := logging.MustGetLogger("test")
cc.RegisterLogger(logger)
require.Equal(t, logger, cc.log)
}

// TestContext_Start exercises Context.Start with the captured command
// replaced by small external commands. All subtests share one Context.
func TestContext_Start(t *testing.T) {
cc := CaptureContext()
assert.NotZero(t, len(cc.cmd.Args))

// A command that runs and exits cleanly: Start must report no error, and
// removing the touched file succeeds only if the command actually ran.
t.Run("executable started", func(t *testing.T) {
cmd := "touch"
path := "/tmp/test_restart"
cc.cmd = exec.Command(cmd, path) // nolint:gosec
// Do not append the delay flag to the arguments of the test command.
cc.appendDelay = false

assert.NoError(t, cc.Start())
assert.NoError(t, os.Remove(path))
})

// A command that does not exist: Start must return the lookup error.
t.Run("bad args", func(t *testing.T) {
cmd := "bad_command"
cc.cmd = exec.Command(cmd) // nolint:gosec

// TODO(nkryuchkov): Check if it works on Linux and Windows, if not then change the error text.
assert.EqualError(t, cc.Start(), `exec: "bad_command": executable file not found in $PATH`)
})

// Two overlapping Start calls: one must succeed and the other must be
// rejected with ErrAlreadyStarting.
// NOTE(review): the assertions assume the direct call below acquires the
// start flag before the goroutine does, and that the goroutine runs while
// the direct call is still in progress - confirm this holds on all
// schedulers and platforms.
t.Run("already restarting", func(t *testing.T) {
cmd := "touch"
path := "/tmp/test_restart"
cc.cmd = exec.Command(cmd, path) // nolint:gosec
cc.appendDelay = false

// Buffered so the goroutine can deliver its result without blocking.
ch := make(chan error, 1)
go func() {
ch <- cc.Start()
}()

assert.NoError(t, cc.Start())
assert.Equal(t, ErrAlreadyStarting, <-ch)

assert.NoError(t, os.Remove(path))
})
}

func TestContext_SetCheckDelay(t *testing.T) {
cc := CaptureContext()
require.Equal(t, DefaultCheckDelay, cc.checkDelay)

const oneSecond = 1 * time.Second

cc.SetCheckDelay(oneSecond)
require.Equal(t, oneSecond, cc.checkDelay)
}
2 changes: 2 additions & 0 deletions pkg/visor/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ type Config struct {
Interfaces InterfaceConfig `json:"interfaces"`

AppServerSockFile string `json:"app_server_sock_file"`

RestartCheckDelay string `json:"restart_check_delay"`
}

// MessagingConfig returns config for dmsg client.
Expand Down
28 changes: 28 additions & 0 deletions pkg/visor/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"os"
"path/filepath"
"time"

Expand All @@ -29,6 +30,9 @@ var (

// ErrNotFound is returned when a requested resource is not found.
ErrNotFound = errors.New("not found")

// ErrMalformedRestartContext is returned when restart context is malformed.
ErrMalformedRestartContext = errors.New("restart context is malformed")
)

// RPC defines RPC methods for Node.
Expand Down Expand Up @@ -390,3 +394,27 @@ func (r *RPC) Loops(_ *struct{}, out *[]LoopInfo) error {
*out = loops
return nil
}

/*
<<< VISOR MANAGEMENT >>>
*/

// exitDelay is the pause between a successful Restart call and the os.Exit
// that terminates the current process.
// NOTE(review): presumably this gives the RPC reply time to be sent - confirm.
const exitDelay = 100 * time.Millisecond

// Restart starts a new visor instance through the node's restart context.
// When that succeeds, the current process exits with status 0 after
// exitDelay. It returns ErrMalformedRestartContext if the node has no restart
// context, or the error reported by the restart context otherwise.
func (r *RPC) Restart(_ *struct{}, _ *struct{}) (err error) {
	// exitLater terminates the current process once exitDelay has passed.
	exitLater := func() {
		time.Sleep(exitDelay)
		os.Exit(0)
	}

	// The exit is scheduled from a deferred function so that it observes the
	// final value of the named result.
	defer func() {
		if err != nil {
			return
		}

		go exitLater()
	}()

	restartCtx := r.node.restartCtx
	if restartCtx == nil {
		return ErrMalformedRestartContext
	}

	return restartCtx.Start()
}
Loading

0 comments on commit 137c9c3

Please sign in to comment.