From ceeaad49891fe299b3cfb47be0aebfc81b5378fa Mon Sep 17 00:00:00 2001 From: Gareth George Date: Thu, 1 May 2025 21:38:05 -0700 Subject: [PATCH] fix: update restic install script to allow newer versions of restic if present in the path --- internal/resticinstaller/resticinstaller.go | 150 +++++--------------- internal/resticinstaller/urls.go | 26 ++++ internal/resticinstaller/version.go | 75 ++++++++++ internal/resticinstaller/version_test.go | 59 ++++++++ 4 files changed, 198 insertions(+), 112 deletions(-) create mode 100644 internal/resticinstaller/urls.go create mode 100644 internal/resticinstaller/version.go create mode 100644 internal/resticinstaller/version_test.go diff --git a/internal/resticinstaller/resticinstaller.go b/internal/resticinstaller/resticinstaller.go index 36c253aa..3a50acb9 100644 --- a/internal/resticinstaller/resticinstaller.go +++ b/internal/resticinstaller/resticinstaller.go @@ -7,7 +7,6 @@ import ( "os/exec" "path" "path/filepath" - "regexp" "runtime" "strings" "sync" @@ -24,53 +23,13 @@ var ( var ( RequiredResticVersion = "0.18.0" + requiredVersionSemver = mustParseSemVer(RequiredResticVersion) + tryFindRestic sync.Once findResticErr error foundResticPath string ) -func getResticVersion(binary string) (string, error) { - cmd := exec.Command(binary, "version") - out, err := cmd.Output() - if err != nil { - return "", fmt.Errorf("exec %v: %w", cmd.String(), err) - } - match := regexp.MustCompile(`restic\s+((\d+\.\d+\.\d+))`).FindSubmatch(out) - if len(match) < 2 { - return "", fmt.Errorf("could not find restic version in output: %s", out) - } - return string(match[1]), nil -} - -func assertResticVersion(binary string) error { - if version, err := getResticVersion(binary); err != nil { - return fmt.Errorf("determine restic version: %w", err) - } else if version != RequiredResticVersion { - return fmt.Errorf("want restic %v but found version %v", RequiredResticVersion, version) - } - return nil -} - -func resticDownloadURL(version string) string { - if runtime.GOOS == "windows" { - // restic is only built for 386 and amd64 on Windows, default to amd64 for other platforms (e.g. arm64.) - arch := "amd64" - if runtime.GOARCH == "386" || runtime.GOARCH == "amd64" { - arch = runtime.GOARCH - } - return fmt.Sprintf("https://github.com/restic/restic/releases/download/v%v/restic_%v_windows_%v.zip", version, version, arch) - } - return fmt.Sprintf("https://github.com/restic/restic/releases/download/v%v/restic_%v_%v_%v.bz2", version, version, runtime.GOOS, runtime.GOARCH) -} - -func hashDownloadURL(version string) string { - return fmt.Sprintf("https://github.com/restic/restic/releases/download/v%v/SHA256SUMS", version) -} - -func sigDownloadURL(version string) string { - return fmt.Sprintf("https://github.com/restic/restic/releases/download/v%v/SHA256SUMS.asc", version) -} - func verify(sha256 string) error { sha256sums, err := getURL(hashDownloadURL(RequiredResticVersion)) if err != nil { @@ -94,79 +53,57 @@ func verify(sha256 string) error { return nil } -func installResticIfNotExists(resticInstallPath string) error { - if _, err := os.Stat(resticInstallPath); err == nil { - // file is now installed, probably by another process. We can return. - return nil - } - - if err := os.MkdirAll(path.Dir(resticInstallPath), 0755); err != nil { - return fmt.Errorf("create restic install directory %v: %w", path.Dir(resticInstallPath), err) - } - - hash, err := downloadFile(resticDownloadURL(RequiredResticVersion), resticInstallPath+".tmp") +func installRestic(targetPath string) error { + sha256sum, err := downloadFile(resticDownloadURL(RequiredResticVersion), targetPath+".tmp") if err != nil { - return err + return fmt.Errorf("downloading: %w", err) } - if err := verify(hash); err != nil { - os.Remove(resticInstallPath) // try to remove the bad binary. - return fmt.Errorf("failed to verify the authenticity of the downloaded restic binary: %v", err) + if err := verify(sha256sum); err != nil { + return fmt.Errorf("verifying: %w", err) } - if err := os.Chmod(resticInstallPath+".tmp", 0755); err != nil { - return fmt.Errorf("chmod executable %v: %w", resticInstallPath, err) + if err := os.Rename(targetPath+".tmp", targetPath); err != nil { + return fmt.Errorf("renaming %v: %w", targetPath, err) } - if err := os.Rename(resticInstallPath+".tmp", resticInstallPath); err != nil { - return fmt.Errorf("rename %v.tmp to %v: %w", resticInstallPath, resticInstallPath, err) + if err := os.Chmod(targetPath, 0755); err != nil { + return fmt.Errorf("chmod executable %v: %w", targetPath, err) } return nil } -func removeOldVersions(installDir string) { - files, err := os.ReadDir(installDir) - if err != nil { - zap.S().Errorf("remove old restic versions: read dir %v: %v", installDir, err) - return +func findOrDownloadRestic(installPath string) error { + if err := assertResticVersion(installPath, true /* strict */); err == nil { + return nil } - for _, file := range files { - if !strings.HasPrefix(file.Name(), "restic-") || strings.Contains(file.Name(), RequiredResticVersion) { - continue - } - - if err := os.Remove(path.Join(installDir, file.Name())); err != nil { - zap.S().Errorf("remove old restic version %v: %v", file.Name(), err) - } + lock := flock.New(filepath.Join(filepath.Dir(installPath), "install.lock")) + if err := lock.Lock(); err != nil { + return fmt.Errorf("acquire lock on restic install dir %v: %v", lock.Path(), err) } + defer lock.Unlock() + + if err := assertResticVersion(installPath, true /* strict */); err == nil { + return nil + } else { + zap.S().Infof("restic binary %v failed version validation: %v", installPath, err) + } + + zap.S().Infof("installing restic to %v", installPath) + if err := installRestic(installPath); err != nil { + return fmt.Errorf("install restic: %w", err) + } + + return nil } -func installResticHelper(resticInstallPath string) { - if _, err := os.Stat(resticInstallPath); err == nil { - zap.S().Infof("replacing restic binary in data dir due to failed check: %v", err) - if err := os.Remove(resticInstallPath); err != nil { - zap.S().Errorf("failed to remove old restic binary %v: %v", resticInstallPath, err) - } - } - - zap.S().Infof("downloading restic %v to %v...", RequiredResticVersion, resticInstallPath) - if err := installResticIfNotExists(resticInstallPath); err != nil { - zap.S().Errorf("failed to install restic %v: %v", RequiredResticVersion, err) - return - } - zap.S().Infof("installed restic %v", RequiredResticVersion) - - // TODO: this check is no longer needed, remove it after a few releases. - removeOldVersions(path.Dir(resticInstallPath)) -} - -func tryFindOrInstall() (string, error) { +func findHelper() (string, error) { // Check if restic is provided. resticBinOverride := env.ResticBinPath() if resticBinOverride != "" { - if err := assertResticVersion(resticBinOverride); err != nil { + if err := assertResticVersion(resticBinOverride, false /* strict */); err != nil { zap.S().Warnf("restic binary %q may not be supported by backrest: %v", resticBinOverride, err) } @@ -181,7 +118,7 @@ func tryFindOrInstall() (string, error) { // Search the PATH for the specific restic version. if binPath, err := exec.LookPath("restic"); err == nil { - if err := assertResticVersion(binPath); err == nil { + if err := assertResticVersion(binPath, false /* strict */); err == nil { zap.S().Infof("restic binary %q in $PATH matches required version %v, it will be used for backrest commands", binPath, RequiredResticVersion) return binPath, nil } else { @@ -201,29 +138,18 @@ func tryFindOrInstall() (string, error) { return "", fmt.Errorf("create restic install directory %v: %w", path.Dir(resticInstallPath), err) } - // Install restic if not found OR if the version is not the required version - if err := assertResticVersion(resticInstallPath); err != nil { - lock := flock.New(filepath.Join(filepath.Dir(resticInstallPath), "install.lock")) - if err := lock.Lock(); err != nil { - return "", fmt.Errorf("acquire lock on restic install dir %v: %v", lock.Path(), err) - } - defer lock.Unlock() - - // Check again after acquiring the lock. - if err := assertResticVersion(resticInstallPath); err != nil { - zap.S().Errorf("could not verify version of binary %v: %v", resticInstallPath, err) - installResticHelper(resticInstallPath) - } + if err := findOrDownloadRestic(resticInstallPath); err != nil { + return "", fmt.Errorf("find or download restic: %w", err) } - zap.S().Infof("restic binary %v in data dir will be used as no system install matching required version %v is found", resticInstallPath, RequiredResticVersion) + zap.S().Infof("restic binary %q in data dir matches required version %v, it will be used for backrest commands", resticInstallPath, RequiredResticVersion) return resticInstallPath, nil } // FindOrInstallResticBinary first tries to find the restic binary if provided as an environment variable. Otherwise it downloads restic if not already installed. func FindOrInstallResticBinary() (string, error) { tryFindRestic.Do(func() { - foundResticPath, findResticErr = tryFindOrInstall() + foundResticPath, findResticErr = findHelper() }) if findResticErr != nil { diff --git a/internal/resticinstaller/urls.go b/internal/resticinstaller/urls.go new file mode 100644 index 00000000..9a221ac0 --- /dev/null +++ b/internal/resticinstaller/urls.go @@ -0,0 +1,26 @@ +package resticinstaller + +import ( + "fmt" + "runtime" +) + +func resticDownloadURL(version string) string { + if runtime.GOOS == "windows" { + // restic is only built for 386 and amd64 on Windows, default to amd64 for other platforms (e.g. arm64.) + arch := "amd64" + if runtime.GOARCH == "386" || runtime.GOARCH == "amd64" { + arch = runtime.GOARCH + } + return fmt.Sprintf("https://github.com/restic/restic/releases/download/v%v/restic_%v_windows_%v.zip", version, version, arch) + } + return fmt.Sprintf("https://github.com/restic/restic/releases/download/v%v/restic_%v_%v_%v.bz2", version, version, runtime.GOOS, runtime.GOARCH) +} + +func hashDownloadURL(version string) string { + return fmt.Sprintf("https://github.com/restic/restic/releases/download/v%v/SHA256SUMS", version) +} + +func sigDownloadURL(version string) string { + return fmt.Sprintf("https://github.com/restic/restic/releases/download/v%v/SHA256SUMS.asc", version) +} diff --git a/internal/resticinstaller/version.go b/internal/resticinstaller/version.go new file mode 100644 index 00000000..865f9008 --- /dev/null +++ b/internal/resticinstaller/version.go @@ -0,0 +1,75 @@ +package resticinstaller + +import ( + "errors" + "fmt" + "os" + "os/exec" + "regexp" + + "go.uber.org/zap" +) + +func getResticVersion(binary string) (string, error) { + cmd := exec.Command(binary, "version") + out, err := cmd.Output() + // check if error is a binary not found error + if err != nil { + if errors.Is(err, exec.ErrNotFound) { + return "", ErrResticNotFound + } + return "", fmt.Errorf("exec %v: %w", cmd.String(), err) + } + match := regexp.MustCompile(`restic\s+((\d+\.\d+\.\d+))`).FindSubmatch(out) + if len(match) < 2 { + return "", fmt.Errorf("could not find restic version in output: %s", out) + } + return string(match[1]), nil +} + +func assertResticVersion(binary string, strict bool) error { + if _, err := os.Stat(binary); err != nil { + return fmt.Errorf("check if restic binary exists: %w", err) + } + + if version, err := getResticVersion(binary); err != nil { + return fmt.Errorf("determine restic version: %w", err) + } else { + cmp := compareSemVer(mustParseSemVer(version), requiredVersionSemver) + if cmp < 0 { + return fmt.Errorf("restic version %v is less than required version %v", version, RequiredResticVersion) + } else if cmp > 0 && strict { + return fmt.Errorf("restic version %v is newer than required version %v, it may not be supported by backrest", version, RequiredResticVersion) + } else if cmp > 0 { + zap.S().Warnf("restic version %v is newer than required version %v, it may not be supported by backrest", version, RequiredResticVersion) + } + } + return nil +} + +func parseSemVer(version string) ([3]int, error) { + var major, minor, patch int + _, err := fmt.Sscanf(version, "%d.%d.%d", &major, &minor, &patch) + if err != nil { + return [3]int{}, fmt.Errorf("invalid semantic version format: %w", err) + } + return [3]int{major, minor, patch}, nil +} + +func mustParseSemVer(version string) [3]int { + v, err := parseSemVer(version) + if err != nil { + panic(err) + } + return v +} + +func compareSemVer(v1 [3]int, v2 [3]int) int { + if v1[0] != v2[0] { + return v1[0] - v2[0] + } + if v1[1] != v2[1] { + return v1[1] - v2[1] + } + return v1[2] - v2[2] +} diff --git a/internal/resticinstaller/version_test.go b/internal/resticinstaller/version_test.go new file mode 100644 index 00000000..d2d000ab --- /dev/null +++ b/internal/resticinstaller/version_test.go @@ -0,0 +1,59 @@ +package resticinstaller + +import "testing" + +func TestParseSemVer(t *testing.T) { + testCases := []struct { + name string + input string + want [3]int + wantErr bool + }{ + {"Valid version", "0.18.0", [3]int{0, 18, 0}, false}, + {"Invalid version", "1.2", [3]int{}, true}, + {"Empty string", "", [3]int{}, true}, + {"Non-numeric version", "a.b.c", [3]int{}, true}, + {"Version with extra parts", "1.2.3.4", [3]int{1, 2, 3}, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := parseSemVer(tc.input) + if (err != nil) != tc.wantErr { + t.Errorf("parseSemVer(%q) error = %v, wantErr %v", tc.input, err, tc.wantErr) + return + } + if got != tc.want { + t.Errorf("parseSemVer(%q) = %v, want %v", tc.input, got, tc.want) + } + }) + } +} + +func TestCompareSemVer(t *testing.T) { + testCases := []struct { + name string + v1 [3]int + v2 [3]int + want int // 1 if v1 > v2, -1 if v1 < v2, 0 if v1 == v2 + }{ + {"Equal versions", [3]int{1, 2, 3}, [3]int{1, 2, 3}, 0}, + {"v1 major greater", [3]int{2, 0, 0}, [3]int{1, 9, 9}, 1}, + {"v1 major smaller", [3]int{1, 9, 9}, [3]int{2, 0, 0}, -1}, + {"v1 minor greater", [3]int{1, 3, 0}, [3]int{1, 2, 9}, 1}, + {"v1 minor smaller", [3]int{1, 2, 9}, [3]int{1, 3, 0}, -1}, + {"v1 patch greater", [3]int{1, 2, 4}, [3]int{1, 2, 3}, 1}, + {"v1 patch smaller", [3]int{1, 2, 3}, [3]int{1, 2, 4}, -1}, + {"Zero versions equal", [3]int{0, 0, 0}, [3]int{0, 0, 0}, 0}, + {"Mixed zero versions", [3]int{0, 1, 0}, [3]int{0, 0, 9}, 1}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := compareSemVer(tc.v1, tc.v2) + if got != tc.want { + t.Errorf("compareSemVer(%v, %v) = %d, want %d", tc.v1, tc.v2, got, tc.want) + } + }) + } +}