Skip to content

Commit

Permalink
feat(plugin): Calculate SHA256 if it does not exist (#2935)
Browse files Browse the repository at this point in the history
* feat(plugin): Calculate SHA256 if it does not exist

* Add logging
  • Loading branch information
kyleconroy authored Nov 2, 2023
1 parent 4507ede commit f80cee1
Showing 1 changed file with 45 additions and 33 deletions.
78 changes: 45 additions & 33 deletions internal/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -52,20 +53,26 @@ func cacheDir() (string, error) {
var flight singleflight.Group

// Verify the provided sha256 is valid.
func (r *Runner) parseChecksum() (string, error) {
if r.SHA256 == "" {
return "", fmt.Errorf("missing SHA-256 checksum")
func (r *Runner) getChecksum(ctx context.Context) (string, error) {
if r.SHA256 != "" {
return r.SHA256, nil
}
return r.SHA256, nil
// TODO: Add a log line here about something
_, sum, err := r.fetch(ctx, r.URL)
if err != nil {
return "", err
}
slog.Warn("fetching WASM binary to calculate sha256. Set this value in sqlc.yaml to prevent unneeded work", "sha256", sum)
return sum, nil
}

func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
expected, err := r.parseChecksum()
expected, err := r.getChecksum(ctx)
if err != nil {
return nil, err
}
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadSerializedModule(ctx, engine)
return r.loadSerializedModule(ctx, engine, expected)
})
if err != nil {
return nil, err
Expand All @@ -77,17 +84,13 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
return wasmtime.NewModuleDeserialize(engine, data)
}

func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine) ([]byte, error) {
expected, err := r.parseChecksum()
if err != nil {
return nil, err
}
func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}

pluginDir := filepath.Join(cacheDir, expected)
pluginDir := filepath.Join(cacheDir, expectedSha)
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
modPath := filepath.Join(pluginDir, modName)
_, staterr := os.Stat(modPath)
Expand All @@ -99,7 +102,7 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
return data, nil
}

wmod, err := r.loadWASM(ctx, cacheDir, expected)
wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
if err != nil {
return nil, err
}
Expand All @@ -126,53 +129,62 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
return out, nil
}

func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
pluginDir := filepath.Join(cache, expected)
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
_, staterr := os.Stat(pluginPath)

func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
var body io.ReadCloser

switch {
case staterr == nil:
file, err := os.Open(pluginPath)
if err != nil {
return nil, fmt.Errorf("os.Open: %s %w", pluginPath, err)
}
body = file

case strings.HasPrefix(r.URL, "file://"):
file, err := os.Open(strings.TrimPrefix(r.URL, "file://"))
case strings.HasPrefix(uri, "file://"):
file, err := os.Open(strings.TrimPrefix(uri, "file://"))
if err != nil {
return nil, fmt.Errorf("os.Open: %s %w", r.URL, err)
return nil, "", fmt.Errorf("os.Open: %s %w", uri, err)
}
body = file

case strings.HasPrefix(r.URL, "https://"):
req, err := http.NewRequestWithContext(ctx, "GET", r.URL, nil)
case strings.HasPrefix(uri, "https://"):
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
return nil, fmt.Errorf("http.Get: %s %w", r.URL, err)
return nil, "", fmt.Errorf("http.Get: %s %w", uri, err)
}
req.Header.Set("User-Agent", fmt.Sprintf("sqlc/%s Go/%s (%s %s)", info.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("http.Get: %s %w", r.URL, err)
return nil, "", fmt.Errorf("http.Get: %s %w", r.URL, err)
}
body = resp.Body

default:
return nil, fmt.Errorf("unknown scheme: %s", r.URL)
return nil, "", fmt.Errorf("unknown scheme: %s", r.URL)
}

defer body.Close()

wmod, err := io.ReadAll(body)
if err != nil {
return nil, fmt.Errorf("readall: %w", err)
return nil, "", fmt.Errorf("readall: %w", err)
}

sum := sha256.Sum256(wmod)
actual := fmt.Sprintf("%x", sum)

return wmod, actual, nil
}

func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
pluginDir := filepath.Join(cache, expected)
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
_, staterr := os.Stat(pluginPath)

uri := r.URL
if staterr == nil {
uri = "file://" + pluginPath
}

wmod, actual, err := r.fetch(ctx, uri)
if err != nil {
return nil, err
}

if expected != actual {
return nil, fmt.Errorf("invalid checksum: expected %s, got %s", expected, actual)
}
Expand Down

0 comments on commit f80cee1

Please sign in to comment.