diff --git a/pkg/util/updater/updater.go b/pkg/util/updater/updater.go index 33b4c5d5b1..66df8365b7 100644 --- a/pkg/util/updater/updater.go +++ b/pkg/util/updater/updater.go @@ -14,7 +14,6 @@ import ( "path/filepath" "runtime" "strings" - "time" "unicode" "github.com/SkycoinProject/skycoin/src/util/logging" @@ -31,7 +30,9 @@ const ( checksumsFilename = "checksums.txt" checkSumLength = 64 permRWX = 0755 - exitDelay = 100 * time.Millisecond + oldSuffix = ".old" + visorBinary = "skywire-visor" + cliBinary = "skywire-cli" ) var ( @@ -57,6 +58,7 @@ func New(log *logging.Logger, restartCtx *restart.Context) *Updater { } // Update performs an update operation. +// NOTE: Update may call os.Exit. func (u *Updater) Update() error { u.log.Infof("Looking for updates") @@ -65,25 +67,64 @@ func (u *Updater) Update() error { return fmt.Errorf("failed to get last visor version: %w", err) } - u.log.Infof("Last visor version: %q", lastVersion.String()) + u.log.Infof("Last Skywire version: %q", lastVersion.String()) if !updateAvailable(lastVersion) { - u.log.Infof("You are using the latest version of visor") + u.log.Infof("You are using the latest version of Skywire") return nil } u.log.Infof("Update found, version: %q", lastVersion.String()) - path, err := u.download(lastVersion.String()) + downloadedVisorPath, err := u.download(visorBinary, lastVersion.String()) if err != nil { return err } - return u.start(path) + downloadedCLIPath, err := u.download(cliBinary, lastVersion.String()) + if err != nil { + return err + } + + currentVisorPath := u.restartCtx.CmdPath() + currentCLIPath := cliPath(currentVisorPath) + + oldCLIPath := downloadedCLIPath + oldSuffix + oldVisorPath := downloadedVisorPath + oldSuffix + + if err := u.updateBinary(downloadedCLIPath, currentCLIPath, oldCLIPath); err != nil { + return fmt.Errorf("failed to update %s binary: %w", cliBinary, err) + } + + if err := u.updateBinary(downloadedVisorPath, currentVisorPath, oldVisorPath); err != nil { + return fmt.Errorf("failed to update %s binary: %w", visorBinary, err) + } + + if err := u.restartCurrentProcess(); err != nil { + u.restore(currentVisorPath, oldVisorPath) + return err + } + + u.removeFiles(oldVisorPath, oldCLIPath) + + u.log.Infof("Exiting") + os.Exit(0) + + // Unreachable. + return nil +} + +// restore restores old binary file. +func (u *Updater) restore(currentBinaryPath string, toBeRemoved string) { + u.removeFiles(currentBinaryPath) + + if err := os.Rename(toBeRemoved, currentBinaryPath); err != nil { + u.log.Errorf("Failed to rename file %q to %q: %v", toBeRemoved, currentBinaryPath, err) + } } -func (u *Updater) download(version string) (string, error) { - checksumsURL := fileURL(version, checksumFile(version)) +func (u *Updater) download(binaryName, version string) (string, error) { + checksumsURL := fileURL(version, checksumFile(binaryName, version)) u.log.Infof("Checksum file URL: %q", checksumsURL) checksums, err := downloadChecksums(checksumsURL) @@ -93,7 +134,7 @@ func (u *Updater) download(version string) (string, error) { u.log.Infof("Checksums file downloaded") - binaryFilename := binaryFilename(version, runtime.GOOS, runtime.GOARCH) + binaryFilename := binaryFilename(binaryName, version, runtime.GOOS, runtime.GOARCH) u.log.Infof("Binary filename: %v", binaryFilename) checksum, err := getChecksum(checksums, binaryFilename) @@ -124,64 +165,26 @@ func (u *Updater) download(version string) (string, error) { return path, nil } -func (u *Updater) start(path string) error { - currentBinaryPath := u.restartCtx.CmdPath() - - toBeRemoved, err := u.updateBinary(path, currentBinaryPath) - if err != nil { - return fmt.Errorf("failed to update binary: %w", err) - } - - u.log.Infof("Need to remove file in %q", toBeRemoved) - - defer func() { - if err == nil { - go func() { - time.Sleep(exitDelay) - - u.log.Infof("Removing file in %q", toBeRemoved) - - if err := os.Remove(toBeRemoved); err != nil { - u.log.Errorf("Failed to remove file %q: %v", toBeRemoved, err) - } - - u.log.Infof("Exiting") - os.Exit(0) - }() - } - }() - +func (u *Updater) restartCurrentProcess() error { u.log.Infof("Starting new file instance") if err := u.restartCtx.Start(); err != nil { - u.log.Errorf("Failed to restart visor: %v", err) - - // Restore old binary file - if err := os.Remove(currentBinaryPath); err != nil { - u.log.Errorf("Failed to remove file %q: %v", currentBinaryPath, err) - } - - if err := os.Rename(toBeRemoved, currentBinaryPath); err != nil { - u.log.Errorf("Failed to rename file %q to %q: %v", toBeRemoved, currentBinaryPath, err) - } - - return fmt.Errorf("failed to restart visor: %w", err) + u.log.Errorf("Failed to start binary: %v", err) + return err } return nil } -func (u *Updater) updateBinary(downloadPath, currentPath string) (toBeRemoved string, err error) { - oldPath := currentPath + ".old" - +func (u *Updater) updateBinary(downloadPath, currentPath, oldPath string) error { if _, err := os.Stat(oldPath); err == nil { if err := os.Remove(oldPath); err != nil { - return "", err + return err } } if err := os.Rename(currentPath, oldPath); err != nil { - return "", err + return err } if err := os.Rename(downloadPath, currentPath); err != nil { @@ -190,10 +193,19 @@ func (u *Updater) updateBinary(downloadPath, currentPath string) (toBeRemoved st u.log.Errorf("Failed to rename file %q to %q: %v", oldPath, currentPath, err) } - return "", err + return err } - return oldPath, nil + return nil +} + +func (u *Updater) removeFiles(names ...string) { + for _, name := range names { + if err := os.Remove(name); err != nil { + u.log.Infof("Removing file %q", name) + u.log.Errorf("Failed to remove file %q: %v", name, err) + } + } } func isChecksumValid(filename, wantSum string) (bool, error) { @@ -300,16 +312,16 @@ func downloadFile(url, filename string) (path string, err error) { return path, nil } -func fileURL(version string, filename string) string { +func fileURL(version, filename string) string { return releaseURL + "/download/" + version + "/" + filename } -func checksumFile(version string) string { - return "skywire-visor-" + version + "-" + checksumsFilename +func checksumFile(binaryName, version string) string { + return binaryName + "-" + version + "-" + checksumsFilename } -func binaryFilename(version string, os, arch string) string { - return "skywire-visor-" + version + "-" + os + "-" + arch +func binaryFilename(binaryName, version, os, arch string) string { + return binaryName + "-" + version + "-" + os + "-" + arch } func updateAvailable(last *Version) bool { @@ -366,3 +378,7 @@ func extractLastVersion(buffer string) string { return versionWithRest[:idx] } + +func cliPath(visorPath string) string { + return filepath.Join(filepath.Dir(visorPath), cliBinary) +} diff --git a/pkg/util/updater/updater_test.go b/pkg/util/updater/updater_test.go index 75d108478b..ace98fa0b0 100644 --- a/pkg/util/updater/updater_test.go +++ b/pkg/util/updater/updater_test.go @@ -167,46 +167,76 @@ func Test_fileURL(t *testing.T) { func Test_checksumFile(t *testing.T) { tests := []struct { - name string - version string - want string + name string + binaryName string + version string + want string }{ { - name: "Case 1", - version: "1.2.3", - want: "skywire-visor-1.2.3-checksums.txt", + name: "Case 1", + binaryName: "skywire-visor", + version: "1.2.3", + want: "skywire-visor-1.2.3-checksums.txt", }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.want, checksumFile(tc.version)) + require.Equal(t, tc.want, checksumFile(tc.binaryName, tc.version)) }) } } func Test_binaryFilename(t *testing.T) { tests := []struct { - name string - version string - os string - arch string - want string + name string + binaryName string + version string + os string + arch string + want string }{ { - name: "Case 1", - version: "1.2.3", - os: "linux", - arch: "amd64", - want: "skywire-visor-1.2.3-linux-amd64", + name: "Case 1", + binaryName: "skywire-visor", + version: "1.2.3", + os: "linux", + arch: "amd64", + want: "skywire-visor-1.2.3-linux-amd64", }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - require.Equal(t, tc.want, binaryFilename(tc.version, tc.os, tc.arch)) + require.Equal(t, tc.want, binaryFilename(tc.binaryName, tc.version, tc.os, tc.arch)) + }) + } +} + +func Test_cliPath(t *testing.T) { + tests := []struct { + name string + visorPath string + want string + }{ + { + name: "Case 1", + visorPath: "/dir1/dir2/visor", + want: "/dir1/dir2/skywire-cli", + }, + { + name: "Case 2", + visorPath: "/dir3/dir4/../dir5/visor", + want: "/dir3/dir5/skywire-cli", + }, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + got := cliPath(tc.visorPath) + assert.Equal(t, tc.want, got) }) } }