diff --git a/consulfs/consul.go b/consulfs/consul.go index 790d64f3..f513574f 100644 --- a/consulfs/consul.go +++ b/consulfs/consul.go @@ -182,7 +182,7 @@ func (f *consulFS) Open(name string) (fs.File, error) { return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrInvalid} } - u, err := subURL(f.base, name) + u, err := internal.SubURL(f.base, name) if err != nil { return nil, err } @@ -209,7 +209,7 @@ func (f *consulFS) ReadFile(name string) ([]byte, error) { } } - u, err := subURL(f.base, name) + u, err := internal.SubURL(f.base, name) if err != nil { return nil, err } @@ -229,15 +229,6 @@ func (f *consulFS) ReadFile(name string) ([]byte, error) { return kvPair.Value, nil } -func subURL(base *url.URL, name string) (*url.URL, error) { - rel, err := url.Parse(name) - if err != nil { - return nil, err - } - - return base.ResolveReference(rel), nil -} - type consulFile struct { ctx context.Context name string @@ -398,7 +389,7 @@ func (f *consulFile) childFile(childName string) *consulFile { parent.Path += "/" } - childURL, _ := subURL(&parent, childName) + childURL, _ := internal.SubURL(&parent, childName) cf := &consulFile{ ctx: f.ctx, diff --git a/consulfs/consul_test.go b/consulfs/consul_test.go index 6667da9f..ff90a524 100644 --- a/consulfs/consul_test.go +++ b/consulfs/consul_test.go @@ -365,18 +365,6 @@ func TestReadDirN(t *testing.T) { assert.Len(t, de, 3) } -func TestSubURL(t *testing.T) { - base := tests.MustURL("https://example.com/dir/") - sub, err := subURL(base, "sub") - assert.NoError(t, err) - assert.Equal(t, "https://example.com/dir/sub", sub.String()) - - base = tests.MustURL("consul:///dir/") - sub, err = subURL(base, "sub/foo?param=foo") - assert.NoError(t, err) - assert.Equal(t, "consul:///dir/sub/foo?param=foo", sub.String()) -} - func TestStat(t *testing.T) { config := fakeConsulServer(t) diff --git a/httpfs/http.go b/httpfs/http.go index 251964c6..e3b9bdf9 100644 --- a/httpfs/http.go +++ b/httpfs/http.go @@ -104,7 +104,7 @@ func (f httpFS) Open(name string) (fs.File, error) { } } - u, err := f.subURL(name) + u, err := internal.SubURL(f.base, name) if err != nil { return nil, err } @@ -136,7 +136,7 @@ func (f httpFS) ReadFile(name string) ([]byte, error) { func (f httpFS) Sub(name string) (fs.FS, error) { fsys := f - u, err := f.subURL(name) + u, err := internal.SubURL(f.base, name) if err != nil { return nil, err } @@ -146,15 +146,6 @@ func (f httpFS) Sub(name string) (fs.FS, error) { return &fsys, nil } -func (f *httpFS) subURL(name string) (*url.URL, error) { - rel, err := url.Parse(name) - if err != nil { - return nil, err - } - - return f.base.ResolveReference(rel), nil -} - type httpFile struct { ctx context.Context body io.ReadCloser diff --git a/httpfs/http_test.go b/httpfs/http_test.go index f855793a..1ea0d373 100644 --- a/httpfs/http_test.go +++ b/httpfs/http_test.go @@ -2,6 +2,7 @@ package httpfs import ( "context" + "encoding/json" "fmt" "io" "io/fs" @@ -14,6 +15,7 @@ import ( "github.com/hairyhenderson/go-fsimpl" "github.com/hairyhenderson/go-fsimpl/internal/tests" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func setupHTTP(t *testing.T) *httptest.Server { @@ -35,6 +37,19 @@ func setupHTTP(t *testing.T) *httptest.Server { _, _ = w.Write([]byte(`{"msg": "hi there"}`)) }) + mux.HandleFunc("/params", func(w http.ResponseWriter, r *http.Request) { + // just returns params as JSON + w.Header().Set("Content-Type", "application/json") + + t.Logf("url: %v", r.URL) + t.Logf("params: %v", r.URL.Query()) + + err := json.NewEncoder(w).Encode(r.URL.Query()) + if err != nil { + t.Fatalf("error encoding: %v", err) + } + }) + srv := httptest.NewServer(mux) t.Cleanup(srv.Close) @@ -83,6 +98,21 @@ func TestHttpFS(t *testing.T) { _, err = fs.Stat(fsys, "bogus") assert.Error(t, err) + + t.Run("base URL query params are preserved", func(t *testing.T) { + fsys, _ = New(tests.MustURL(srv.URL + "/?foo=bar&baz=qux")) + fsys = fsimpl.WithContextFS(ctx, fsys) + + f, err := fsys.Open("params") + assert.NoError(t, err) + + defer f.Close() + + body, err := io.ReadAll(f) + require.NoError(t, err) + + assert.JSONEq(t, `{"foo":["bar"],"baz":["qux"]}`, string(body)) + }) } func setupExampleHTTPServer() *httptest.Server { diff --git a/internal/url.go b/internal/url.go new file mode 100644 index 00000000..3e75dd93 --- /dev/null +++ b/internal/url.go @@ -0,0 +1,26 @@ +package internal + +import "net/url" + +func SubURL(base *url.URL, name string) (*url.URL, error) { + rel, err := url.Parse(name) + if err != nil { + return nil, err + } + + u := base.ResolveReference(rel) + + // also merge query params + if base.RawQuery != "" { + bq := base.Query() + rq := rel.Query() + + for k := range rq { + bq.Set(k, rq.Get(k)) + } + + u.RawQuery = bq.Encode() + } + + return u, nil +} diff --git a/internal/url_test.go b/internal/url_test.go new file mode 100644 index 00000000..2f1deb68 --- /dev/null +++ b/internal/url_test.go @@ -0,0 +1,31 @@ +package internal + +import ( + "testing" + + "github.com/hairyhenderson/go-fsimpl/internal/tests" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSubURL(t *testing.T) { + base := tests.MustURL("https://example.com/dir/") + sub, err := SubURL(base, "sub") + assert.NoError(t, err) + assert.Equal(t, "https://example.com/dir/sub", sub.String()) + + base = tests.MustURL("consul:///dir/") + sub, err = SubURL(base, "sub/foo?param=foo") + assert.NoError(t, err) + assert.Equal(t, "consul:///dir/sub/foo?param=foo", sub.String()) + + base = tests.MustURL("vault:///dir/?param1=foo¶m2=bar") + sub, err = SubURL(base, "sub/foo") + require.NoError(t, err) + assert.Equal(t, "vault:///dir/sub/foo?param1=foo¶m2=bar", sub.String()) + + base = tests.MustURL("consul:///dir/?param1=foo¶m2=bar") + sub, err = SubURL(base, "sub/foo?param3=baz") + require.NoError(t, err) + assert.Equal(t, "consul:///dir/sub/foo?param1=foo¶m2=bar¶m3=baz", sub.String()) +} diff --git a/vaultfs/vault.go b/vaultfs/vault.go index 60704191..a7006d17 100644 --- a/vaultfs/vault.go +++ b/vaultfs/vault.go @@ -159,7 +159,7 @@ func (f vaultFS) Open(name string) (fs.File, error) { return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrInvalid} } - u, err := f.subURL(name) + u, err := internal.SubURL(f.base, name) if err != nil { return nil, err } @@ -193,29 +193,6 @@ func (f vaultFS) ReadFile(name string) ([]byte, error) { return b, nil } -func (f *vaultFS) subURL(name string) (*url.URL, error) { - rel, err := url.Parse(name) - if err != nil { - return nil, err - } - - u := f.base.ResolveReference(rel) - - // also merge query params - if f.base.RawQuery != "" { - bq := f.base.Query() - rq := rel.Query() - - for k := range rq { - bq.Set(k, rq.Get(k)) - } - - u.RawQuery = bq.Encode() - } - - return u, nil -} - // newVaultFile opens a vault file/dir for reading - if this file is not closed // a vault token may be leaked! func newVaultFile(ctx context.Context, name string, u *url.URL, client *refCountedClient, auth api.AuthMethod) *vaultFile { diff --git a/vaultfs/vault_test.go b/vaultfs/vault_test.go index 02654d17..a776ef1a 100644 --- a/vaultfs/vault_test.go +++ b/vaultfs/vault_test.go @@ -348,29 +348,6 @@ func TestReadDirN(t *testing.T) { assert.Len(t, de, 3) } -func TestSubURL(t *testing.T) { - fsys := &vaultFS{base: tests.MustURL("https://example.com/v1/secret/")} - - sub, err := fsys.subURL("foo") - assert.NoError(t, err) - assert.Equal(t, "https://example.com/v1/secret/foo", sub.String()) - - fsys = &vaultFS{base: tests.MustURL("vault:///v1/secret/")} - sub, err = fsys.subURL("sub/foo?param=foo") - assert.NoError(t, err) - assert.Equal(t, "vault:///v1/secret/sub/foo?param=foo", sub.String()) - - fsys = &vaultFS{base: tests.MustURL("vault:///v1/secret/?param1=foo¶m2=bar")} - sub, err = fsys.subURL("sub/foo") - assert.NoError(t, err) - assert.Equal(t, "vault:///v1/secret/sub/foo?param1=foo¶m2=bar", sub.String()) - - fsys = &vaultFS{base: tests.MustURL("vault:///v1/secret/?param1=foo¶m2=bar")} - sub, err = fsys.subURL("sub/foo?param3=baz") - assert.NoError(t, err) - assert.Equal(t, "vault:///v1/secret/sub/foo?param1=foo¶m2=bar¶m3=baz", sub.String()) -} - func TestStat(t *testing.T) { v := newRefCountedClient(fakeVaultServer(t))