diff --git a/cmd/skywire-visor/commands/root.go b/cmd/skywire-visor/commands/root.go index 1b49da606f..ed47be7292 100644 --- a/cmd/skywire-visor/commands/root.go +++ b/cmd/skywire-visor/commands/root.go @@ -23,6 +23,7 @@ import ( "github.com/spf13/cobra" "github.com/SkycoinProject/skywire-mainnet/internal/utclient" + "github.com/SkycoinProject/skywire-mainnet/pkg/restart" "github.com/SkycoinProject/skywire-mainnet/pkg/util/pathutil" "github.com/SkycoinProject/skywire-mainnet/pkg/visor" ) @@ -39,6 +40,7 @@ type runCfg struct { cfgFromStdin bool profileMode string port string + startDelay string args []string profileStop func() @@ -46,6 +48,7 @@ type runCfg struct { masterLogger *logging.MasterLogger conf visor.Config node *visor.Node + restartCtx *restart.Context } var cfg *runCfg @@ -73,6 +76,9 @@ func init() { rootCmd.Flags().BoolVarP(&cfg.cfgFromStdin, "stdin", "i", false, "read config from STDIN") rootCmd.Flags().StringVarP(&cfg.profileMode, "profile", "p", "none", "enable profiling with pprof. Mode: none or one of: [cpu, mem, mutex, block, trace, http]") rootCmd.Flags().StringVarP(&cfg.port, "port", "", "6060", "port for http-mode of pprof") + rootCmd.Flags().StringVarP(&cfg.startDelay, "delay", "", "0ns", "delay before visor start") + + cfg.restartCtx = restart.CaptureContext() } // Execute executes root CLI command. @@ -148,7 +154,19 @@ func (cfg *runCfg) readConfig() *runCfg { } func (cfg *runCfg) runNode() *runCfg { - node, err := visor.NewNode(&cfg.conf, cfg.masterLogger) + startDelay, err := time.ParseDuration(cfg.startDelay) + if err != nil { + cfg.logger.Warnf("Using no visor start delay due to parsing failure: %v", err) + startDelay = time.Duration(0) + } + + if startDelay != 0 { + cfg.logger.Infof("Visor start delay is %v, waiting...", startDelay) + } + + time.Sleep(startDelay) + + node, err := visor.NewNode(&cfg.conf, cfg.masterLogger, cfg.restartCtx) if err != nil { cfg.logger.Fatal("Failed to initialize node: ", err) } @@ -181,7 +199,9 @@ func (cfg *runCfg) runNode() *runCfg { if cfg.conf.ShutdownTimeout == 0 { cfg.conf.ShutdownTimeout = defaultShutdownTimeout } + cfg.node = node + return cfg } diff --git a/pkg/hypervisor/hypervisor.go b/pkg/hypervisor/hypervisor.go index c8134db15d..30863609b3 100644 --- a/pkg/hypervisor/hypervisor.go +++ b/pkg/hypervisor/hypervisor.go @@ -151,6 +151,7 @@ func (m *Node) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.Put("/nodes/{pk}/routes/{rid}", m.putRoute()) r.Delete("/nodes/{pk}/routes/{rid}", m.deleteRoute()) r.Get("/nodes/{pk}/loops", m.getLoops()) + r.Get("/nodes/{pk}/restart", m.restart()) }) }) r.ServeHTTP(w, req) @@ -569,6 +570,18 @@ func (m *Node) getLoops() http.HandlerFunc { }) } +// NOTE: Reply comes with a delay, because of check if new executable is started successfully. +func (m *Node) restart() http.HandlerFunc { + return m.withCtx(m.nodeCtx, func(w http.ResponseWriter, r *http.Request, ctx *httpCtx) { + if err := ctx.RPC.Restart(); err != nil { + httputil.WriteJSON(w, r, http.StatusInternalServerError, err) + return + } + + httputil.WriteJSON(w, r, http.StatusOK, true) + }) +} + /* <<< Helper functions >>> */ diff --git a/pkg/restart/restart.go b/pkg/restart/restart.go new file mode 100644 index 0000000000..0c485ebb8d --- /dev/null +++ b/pkg/restart/restart.go @@ -0,0 +1,155 @@ +package restart + +import ( + "errors" + "log" + "os" + "os/exec" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" +) + +var ( + // ErrAlreadyStarting is returned on starting attempt when starting is in progress. + ErrAlreadyStarting = errors.New("already starting") +) + +const ( + // DefaultCheckDelay is a default delay for checking if a new instance is started successfully. + DefaultCheckDelay = 1 * time.Second + extraWaitingTime = 1 * time.Second + delayArgName = "--delay" +) + +// Context describes data required for restarting visor. +type Context struct { + log logrus.FieldLogger + cmd *exec.Cmd + checkDelay time.Duration + isStarting int32 + appendDelay bool // disabled in tests +} + +// CaptureContext captures data required for restarting visor. +// Data used by CaptureContext must not be modified before, +// therefore calling CaptureContext immediately after starting executable is recommended. +func CaptureContext() *Context { + cmd := exec.Command(os.Args[0], os.Args[1:]...) // nolint:gosec + + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + cmd.Stderr = os.Stderr + cmd.Env = os.Environ() + + return &Context{ + cmd: cmd, + checkDelay: DefaultCheckDelay, + appendDelay: true, + } +} + +// RegisterLogger registers a logger instead of standard one. +func (c *Context) RegisterLogger(logger logrus.FieldLogger) { + if c != nil { + c.log = logger + } +} + +// SetCheckDelay sets a check delay instead of standard one. +func (c *Context) SetCheckDelay(delay time.Duration) { + if c != nil { + c.checkDelay = delay + } +} + +// Start starts a new executable using Context. +func (c *Context) Start() error { + if !atomic.CompareAndSwapInt32(&c.isStarting, 0, 1) { + return ErrAlreadyStarting + } + + defer atomic.StoreInt32(&c.isStarting, 0) + + errCh := c.startExec() + + ticker := time.NewTicker(c.checkDelay) + defer ticker.Stop() + + select { + case err := <-errCh: + c.errorLogger()("Failed to start new instance: %v", err) + return err + case <-ticker.C: + c.infoLogger()("New instance started successfully, exiting from the old one") + return nil + } +} + +func (c *Context) startExec() chan error { + errCh := make(chan error, 1) + + go func() { + defer close(errCh) + + c.adjustArgs() + + c.infoLogger()("Starting new instance of executable (args: %q)", c.cmd.Args) + + if err := c.cmd.Start(); err != nil { + errCh <- err + return + } + + if err := c.cmd.Wait(); err != nil { + errCh <- err + return + } + }() + + return errCh +} + +func (c *Context) adjustArgs() { + args := c.cmd.Args + + i := 0 + l := len(args) + + for i < l { + if args[i] == delayArgName && i < len(args)-1 { + args = append(args[:i], args[i+2:]...) + l -= 2 + } else { + i++ + } + } + + if c.appendDelay { + delay := c.checkDelay + extraWaitingTime + args = append(args, delayArgName, delay.String()) + } + + c.cmd.Args = args +} + +func (c *Context) infoLogger() func(string, ...interface{}) { + if c.log != nil { + return c.log.Infof + } + + logger := log.New(os.Stdout, "[INFO] ", log.LstdFlags) + + return logger.Printf +} + +func (c *Context) errorLogger() func(string, ...interface{}) { + if c.log != nil { + return c.log.Errorf + } + + logger := log.New(os.Stdout, "[ERROR] ", log.LstdFlags) + + return logger.Printf +} diff --git a/pkg/restart/restart_test.go b/pkg/restart/restart_test.go new file mode 100644 index 0000000000..bb98253490 --- /dev/null +++ b/pkg/restart/restart_test.go @@ -0,0 +1,83 @@ +package restart + +import ( + "os" + "os/exec" + "testing" + "time" + + "github.com/SkycoinProject/skycoin/src/util/logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCaptureContext(t *testing.T) { + cc := CaptureContext() + + require.Equal(t, DefaultCheckDelay, cc.checkDelay) + require.Equal(t, os.Args, cc.cmd.Args) + require.Equal(t, os.Stdout, cc.cmd.Stdout) + require.Equal(t, os.Stdin, cc.cmd.Stdin) + require.Equal(t, os.Stderr, cc.cmd.Stderr) + require.Equal(t, os.Environ(), cc.cmd.Env) + require.Nil(t, cc.log) +} + +func TestContext_RegisterLogger(t *testing.T) { + cc := CaptureContext() + require.Nil(t, cc.log) + + logger := logging.MustGetLogger("test") + cc.RegisterLogger(logger) + require.Equal(t, logger, cc.log) +} + +func TestContext_Start(t *testing.T) { + cc := CaptureContext() + assert.NotZero(t, len(cc.cmd.Args)) + + t.Run("executable started", func(t *testing.T) { + cmd := "touch" + path := "/tmp/test_restart" + cc.cmd = exec.Command(cmd, path) // nolint:gosec + cc.appendDelay = false + + assert.NoError(t, cc.Start()) + assert.NoError(t, os.Remove(path)) + }) + + t.Run("bad args", func(t *testing.T) { + cmd := "bad_command" + cc.cmd = exec.Command(cmd) // nolint:gosec + + // TODO(nkryuchkov): Check if it works on Linux and Windows, if not then change the error text. + assert.EqualError(t, cc.Start(), `exec: "bad_command": executable file not found in $PATH`) + }) + + t.Run("already restarting", func(t *testing.T) { + cmd := "touch" + path := "/tmp/test_restart" + cc.cmd = exec.Command(cmd, path) // nolint:gosec + cc.appendDelay = false + + ch := make(chan error, 1) + go func() { + ch <- cc.Start() + }() + + assert.NoError(t, cc.Start()) + assert.Equal(t, ErrAlreadyStarting, <-ch) + + assert.NoError(t, os.Remove(path)) + }) +} + +func TestContext_SetCheckDelay(t *testing.T) { + cc := CaptureContext() + require.Equal(t, DefaultCheckDelay, cc.checkDelay) + + const oneSecond = 1 * time.Second + + cc.SetCheckDelay(oneSecond) + require.Equal(t, oneSecond, cc.checkDelay) +} diff --git a/pkg/visor/config.go b/pkg/visor/config.go index 69e54d8083..c4c1bbbb37 100644 --- a/pkg/visor/config.go +++ b/pkg/visor/config.go @@ -73,6 +73,8 @@ type Config struct { Interfaces InterfaceConfig `json:"interfaces"` AppServerSockFile string `json:"app_server_sock_file"` + + RestartCheckDelay string `json:"restart_check_delay"` } // MessagingConfig returns config for dmsg client. diff --git a/pkg/visor/rpc.go b/pkg/visor/rpc.go index 4b77005df4..7ddd95834a 100644 --- a/pkg/visor/rpc.go +++ b/pkg/visor/rpc.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "os" "path/filepath" "time" @@ -29,6 +30,9 @@ var ( // ErrNotFound is returned when a requested resource is not found. ErrNotFound = errors.New("not found") + + // ErrMalformedRestartContext is returned when restart context is malformed. + ErrMalformedRestartContext = errors.New("restart context is malformed") ) // RPC defines RPC methods for Node. @@ -390,3 +394,27 @@ func (r *RPC) Loops(_ *struct{}, out *[]LoopInfo) error { *out = loops return nil } + +/* + <<< VISOR MANAGEMENT >>> +*/ + +const exitDelay = 100 * time.Millisecond + +// Restart restarts visor. +func (r *RPC) Restart(_ *struct{}, _ *struct{}) (err error) { + defer func() { + if err == nil { + go func() { + time.Sleep(exitDelay) + os.Exit(0) + }() + } + }() + + if r.node.restartCtx == nil { + return ErrMalformedRestartContext + } + + return r.node.restartCtx.Start() +} diff --git a/pkg/visor/rpc_client.go b/pkg/visor/rpc_client.go index 219734fcd9..92882bacb0 100644 --- a/pkg/visor/rpc_client.go +++ b/pkg/visor/rpc_client.go @@ -9,15 +9,14 @@ import ( "sync" "time" - "github.com/SkycoinProject/skywire-mainnet/pkg/app" - "github.com/SkycoinProject/skywire-mainnet/pkg/router" - "github.com/SkycoinProject/skywire-mainnet/pkg/snet/snettest" - "github.com/SkycoinProject/dmsg/cipher" "github.com/SkycoinProject/skycoin/src/util/logging" "github.com/google/uuid" + "github.com/SkycoinProject/skywire-mainnet/pkg/app" + "github.com/SkycoinProject/skywire-mainnet/pkg/router" "github.com/SkycoinProject/skywire-mainnet/pkg/routing" + "github.com/SkycoinProject/skywire-mainnet/pkg/snet/snettest" "github.com/SkycoinProject/skywire-mainnet/pkg/transport" ) @@ -50,6 +49,8 @@ type RPCClient interface { RemoveRoutingRule(key routing.RouteID) error Loops() ([]LoopInfo, error) + + Restart() error } // RPCClient provides methods to call an RPC Server. @@ -222,6 +223,11 @@ func (rc *rpcClient) Loops() ([]LoopInfo, error) { return loops, err } +// Restart calls Restart. +func (rc *rpcClient) Restart() error { + return rc.Call("Restart", &struct{}{}, &struct{}{}) +} + // MockRPCClient mocks RPCClient. type mockRPCClient struct { startedAt time.Time @@ -544,3 +550,8 @@ func (mc *mockRPCClient) Loops() ([]LoopInfo, error) { return loops, nil } + +// Restart implements RPCClient. +func (mc *mockRPCClient) Restart() error { + return nil +} diff --git a/pkg/visor/visor.go b/pkg/visor/visor.go index b51438fc5e..c316a51526 100644 --- a/pkg/visor/visor.go +++ b/pkg/visor/visor.go @@ -26,6 +26,7 @@ import ( "github.com/SkycoinProject/skywire-mainnet/pkg/app/appnet" "github.com/SkycoinProject/skywire-mainnet/pkg/app/appserver" "github.com/SkycoinProject/skywire-mainnet/pkg/dmsgpty" + "github.com/SkycoinProject/skywire-mainnet/pkg/restart" "github.com/SkycoinProject/skywire-mainnet/pkg/routefinder/rfclient" "github.com/SkycoinProject/skywire-mainnet/pkg/router" "github.com/SkycoinProject/skywire-mainnet/pkg/routing" @@ -82,7 +83,8 @@ type Node struct { localPath string appsConf []AppConfig - startedAt time.Time + startedAt time.Time + restartCtx *restart.Context pidMu sync.Mutex @@ -93,7 +95,7 @@ type Node struct { } // NewNode constructs new Node. -func NewNode(config *Config, masterLogger *logging.MasterLogger) (*Node, error) { +func NewNode(config *Config, masterLogger *logging.MasterLogger, restartCtx *restart.Context) (*Node, error) { ctx := context.Background() node := &Node{ @@ -104,6 +106,15 @@ func NewNode(config *Config, masterLogger *logging.MasterLogger) (*Node, error) node.Logger = masterLogger node.logger = node.Logger.PackageLogger("skywire") + restartCheckDelay, err := time.ParseDuration(config.RestartCheckDelay) + if err == nil { + restartCtx.SetCheckDelay(restartCheckDelay) + } + + restartCtx.RegisterLogger(node.logger) + + node.restartCtx = restartCtx + pk := config.Node.StaticPubKey sk := config.Node.StaticSecKey