diff --git a/ipinfo/cmd_download.go b/ipinfo/cmd_download.go index 6cb8e36e..fa1ffc21 100644 --- a/ipinfo/cmd_download.go +++ b/ipinfo/cmd_download.go @@ -2,6 +2,8 @@ package main import ( "compress/gzip" + "crypto/sha256" + "encoding/json" "errors" "fmt" "io" @@ -17,6 +19,14 @@ import ( const dbDownloadURL = "https://ipinfo.io/data/free/" +type ChecksumResponse struct { + Checksums struct { + MD5 string `json:"md5"` + SHA1 string `json:"sha1"` + SHA256 string `json:"sha256"` + } `json:"checksums"` +} + var completionsDownload = &complete.Command{ Flags: map[string]complete.Predictor{ "-c": predict.Nothing, @@ -140,6 +150,24 @@ func cmdDownload() error { return err } + // fetch checksums from API and check if they match. + checksumUrl := fmt.Sprintf("%s%s.%s/checksums?token=%s", dbDownloadURL, dbName, format, token) + checksumResponse, err := fetchChecksums(checksumUrl) + if err != nil { + return err + } + + // compute checksum of downloaded file. + localChecksum, err := computeSHA256(fileName) + if err != nil { + return err + } + + // compare checksums. + if localChecksum != checksumResponse.Checksums.SHA256 { + return errors.New("checksums do not match. File might be corrupted") + } + return nil } @@ -238,3 +266,38 @@ func unzipWrite(file *os.File, data io.Reader) error { return nil } + +func computeSHA256(filepath string) (string, error) { + file, err := os.Open(filepath) + if err != nil { + return "", err + } + defer file.Close() + + hasher := sha256.New() + if _, err := io.Copy(hasher, file); err != nil { + return "", err + } + + return fmt.Sprintf("%x", hasher.Sum(nil)), nil +} + +func fetchChecksums(url string) (*ChecksumResponse, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var checksumResponse ChecksumResponse + if err := json.Unmarshal(body, &checksumResponse); err != nil { + return nil, err + } + + return &checksumResponse, nil +}