Skip to content

Commit

Permalink
feat(vaultfs): Add extension methods WithClient and WithConfig (#881)
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Henderson <[email protected]>
  • Loading branch information
hairyhenderson authored Nov 17, 2024
1 parent 915267e commit 8c0f527
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 1 deletion.
39 changes: 39 additions & 0 deletions vaultfs/extensions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package vaultfs

import (
"io/fs"

"github.com/hashicorp/vault/api"
)

type withClienter interface {
WithClient(client *api.Client) fs.FS
}

// WithClient injects a Vault client into the filesystem fs, if the
// filesystem supports it (i.e. is a [FS], or some other type with a
// WithClient method). The current client will be replaced. It is the
// caller's responsibility to ensure the client is configured correctly.
func WithClient(client *api.Client, fsys fs.FS) fs.FS {
if cfsys, ok := fsys.(withClienter); ok {
return cfsys.WithClient(client)
}

return fsys
}

type withConfiger interface {
WithConfig(config *api.Config) fs.FS
}

// WithConfig injects a Vault configuration into the filesystem fs, if the
// filesystem supports it (i.e. is a [FS], or some other type with a
// WithConfig method). The current client will be replaced. If the
// configuration is invalid, an error will be logged and nil will be returned.
func WithConfig(config *api.Config, fsys fs.FS) fs.FS {
if cfsys, ok := fsys.(withConfiger); ok {
return cfsys.WithConfig(config)
}

return fsys
}
45 changes: 44 additions & 1 deletion vaultfs/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"net/url"
"path"
Expand Down Expand Up @@ -63,7 +64,8 @@ func New(u *url.URL) (fs.FS, error) {
return nil, fmt.Errorf("vault client creation failed: %w", err)
}

fsys := newWithVaultClient(u, newRefCountedClient(c))
fsys := newWithVaultClient(u, nil)
fsys = WithClient(c, fsys).(*vaultFS)
fsys.auth = vaultauth.NewTokenAuth("")

return fsys, nil
Expand Down Expand Up @@ -115,6 +117,8 @@ var (
_ fs.ReadFileFS = (*vaultFS)(nil)
_ internal.WithContexter = (*vaultFS)(nil)
_ internal.WithHeaderer = (*vaultFS)(nil)
_ withClienter = (*vaultFS)(nil)
_ withConfiger = (*vaultFS)(nil)
)

func (f vaultFS) URL() string {
Expand Down Expand Up @@ -148,6 +152,45 @@ func (f *vaultFS) WithHeader(headers http.Header) fs.FS {
return &fsys
}

func (f *vaultFS) WithClient(client *api.Client) fs.FS {
if client == nil {
return f
}

fsys := *f
fsys.client = newRefCountedClient(client)

return &fsys
}

func (f *vaultFS) WithConfig(config *api.Config) fs.FS {
if config == nil {
return f
}

// handle compound URL scheme not supported by the client, but only if the
// URL has a host part set - otherwise use the scheme from $VAULT_ADDR, as
// set by api.DefaultConfig() above
if f.base.Host != "" {
scheme := strings.TrimPrefix(f.base.Scheme, "vault+")
if scheme == "vault" {
scheme = "https"
}

config.Address = scheme + "://" + f.base.Host
}

client, err := api.NewClient(config)
if err != nil {
slog.ErrorContext(f.ctx, "vaultfs: failed to create vault client with user-supplied configuration",
slog.Any("error", err))

return nil
}

return f.WithClient(client)
}

func (f vaultFS) WithAuthMethod(auth api.AuthMethod) fs.FS {
fsys := f
fsys.auth = auth
Expand Down
70 changes: 70 additions & 0 deletions vaultfs/vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,73 @@ func TestFindMountInfo(t *testing.T) {
assert.EqualValues(t, d.expected, actual)
}
}

func TestWithConfig(t *testing.T) {
cl := fakevault.Server(t)

t.Run("config provided", func(t *testing.T) {
config := cl.CloneConfig()
fsys := WithAuthMethod(
TokenAuthMethod("blargh"),
// fsys without vault client - will panic unless a client is injected
newWithVaultClient(tests.MustURL("vault:///secret/"), nil),
)
fsys = WithConfig(config, fsys).(*vaultFS)

f, err := fsys.Open("foo")
require.NoError(t, err)

fi, err := f.Stat()
require.NoError(t, err)
assert.Equal(t, "application/json", fsimpl.ContentType(fi))
})

t.Run("bad config errors with nil fs", func(t *testing.T) {
config := cl.CloneConfig()
config.Address = "bad url://"

vaultFs := newWithVaultClient(tests.MustURL("vault:///secret/"), nil)
assert.Nil(t, WithConfig(config, vaultFs))
})

t.Run("nil config ignored", func(t *testing.T) {
vaultFs := newWithVaultClient(tests.MustURL("vault:///secret/"), nil)
fsys := WithConfig(nil, vaultFs)
assert.Same(t, vaultFs, fsys)
})

t.Run("URL with host overrides what's in the config", func(t *testing.T) {
config := api.DefaultConfig()
testdata := []struct {
url, addr string
}{
{"vault+https://example.com/secret/", "https://example.com"},
{"vault://example.com/secret/foo", "https://example.com"},
{"vault+http://example.com/secret/", "http://example.com"},
}

for _, d := range testdata {
vaultFs := newWithVaultClient(tests.MustURL(d.url), nil)
vaultFs = WithConfig(config, vaultFs).(*vaultFS)
assert.Equal(t, d.addr, vaultFs.client.CloneConfig().Address)
}
})
}

func TestWithClient(t *testing.T) {
cl := fakevault.Server(t)

fsys := WithAuthMethod(
TokenAuthMethod("blargh"),
// fsys without vault client - will panic unless a client is injected
newWithVaultClient(tests.MustURL("vault:///secret/"), nil),
)
fsys = WithClient(cl, fsys).(*vaultFS)

f, err := fsys.Open("foo")
require.NoError(t, err)

fi, err := f.Stat()
require.NoError(t, err)
assert.Equal(t, "application/json", fsimpl.ContentType(fi))
}

0 comments on commit 8c0f527

Please sign in to comment.