From 8c0f527ee0b2c79ccd49c89959e7287b54dce39d Mon Sep 17 00:00:00 2001 From: Dave Henderson Date: Sun, 17 Nov 2024 10:56:03 -0500 Subject: [PATCH] feat(vaultfs): Add extension methods WithClient and WithConfig (#881) Signed-off-by: Dave Henderson --- vaultfs/extensions.go | 39 ++++++++++++++++++++++++ vaultfs/vault.go | 45 +++++++++++++++++++++++++++- vaultfs/vault_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 vaultfs/extensions.go diff --git a/vaultfs/extensions.go b/vaultfs/extensions.go new file mode 100644 index 00000000..71707416 --- /dev/null +++ b/vaultfs/extensions.go @@ -0,0 +1,39 @@ +package vaultfs + +import ( + "io/fs" + + "github.com/hashicorp/vault/api" +) + +type withClienter interface { + WithClient(client *api.Client) fs.FS +} + +// WithClient injects a Vault client into the filesystem fs, if the +// filesystem supports it (i.e. is a [FS], or some other type with a +// WithClient method). The current client will be replaced. It is the +// caller's responsibility to ensure the client is configured correctly. +func WithClient(client *api.Client, fsys fs.FS) fs.FS { + if cfsys, ok := fsys.(withClienter); ok { + return cfsys.WithClient(client) + } + + return fsys +} + +type withConfiger interface { + WithConfig(config *api.Config) fs.FS +} + +// WithConfig injects a Vault configuration into the filesystem fs, if the +// filesystem supports it (i.e. is a [FS], or some other type with a +// WithConfig method). The current client will be replaced. If the +// configuration is invalid, an error will be logged and nil will be returned. +func WithConfig(config *api.Config, fsys fs.FS) fs.FS { + if cfsys, ok := fsys.(withConfiger); ok { + return cfsys.WithConfig(config) + } + + return fsys +} diff --git a/vaultfs/vault.go b/vaultfs/vault.go index f361ee2f..73023d2e 100644 --- a/vaultfs/vault.go +++ b/vaultfs/vault.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "io/fs" + "log/slog" "net/http" "net/url" "path" @@ -63,7 +64,8 @@ func New(u *url.URL) (fs.FS, error) { return nil, fmt.Errorf("vault client creation failed: %w", err) } - fsys := newWithVaultClient(u, newRefCountedClient(c)) + fsys := newWithVaultClient(u, nil) + fsys = WithClient(c, fsys).(*vaultFS) fsys.auth = vaultauth.NewTokenAuth("") return fsys, nil @@ -115,6 +117,8 @@ var ( _ fs.ReadFileFS = (*vaultFS)(nil) _ internal.WithContexter = (*vaultFS)(nil) _ internal.WithHeaderer = (*vaultFS)(nil) + _ withClienter = (*vaultFS)(nil) + _ withConfiger = (*vaultFS)(nil) ) func (f vaultFS) URL() string { @@ -148,6 +152,45 @@ func (f *vaultFS) WithHeader(headers http.Header) fs.FS { return &fsys } +func (f *vaultFS) WithClient(client *api.Client) fs.FS { + if client == nil { + return f + } + + fsys := *f + fsys.client = newRefCountedClient(client) + + return &fsys +} + +func (f *vaultFS) WithConfig(config *api.Config) fs.FS { + if config == nil { + return f + } + + // handle compound URL scheme not supported by the client, but only if the + // URL has a host part set - otherwise use the scheme from $VAULT_ADDR, as + // set by api.DefaultConfig() above + if f.base.Host != "" { + scheme := strings.TrimPrefix(f.base.Scheme, "vault+") + if scheme == "vault" { + scheme = "https" + } + + config.Address = scheme + "://" + f.base.Host + } + + client, err := api.NewClient(config) + if err != nil { + slog.ErrorContext(f.ctx, "vaultfs: failed to create vault client with user-supplied configuration", + slog.Any("error", err)) + + return nil + } + + return f.WithClient(client) +} + func (f vaultFS) WithAuthMethod(auth api.AuthMethod) fs.FS { fsys := f fsys.auth = auth diff --git a/vaultfs/vault_test.go b/vaultfs/vault_test.go index becc1f06..8b6f6494 100644 --- a/vaultfs/vault_test.go +++ b/vaultfs/vault_test.go @@ -415,3 +415,73 @@ func TestFindMountInfo(t *testing.T) { assert.EqualValues(t, d.expected, actual) } } + +func TestWithConfig(t *testing.T) { + cl := fakevault.Server(t) + + t.Run("config provided", func(t *testing.T) { + config := cl.CloneConfig() + fsys := WithAuthMethod( + TokenAuthMethod("blargh"), + // fsys without vault client - will panic unless a client is injected + newWithVaultClient(tests.MustURL("vault:///secret/"), nil), + ) + fsys = WithConfig(config, fsys).(*vaultFS) + + f, err := fsys.Open("foo") + require.NoError(t, err) + + fi, err := f.Stat() + require.NoError(t, err) + assert.Equal(t, "application/json", fsimpl.ContentType(fi)) + }) + + t.Run("bad config errors with nil fs", func(t *testing.T) { + config := cl.CloneConfig() + config.Address = "bad url://" + + vaultFs := newWithVaultClient(tests.MustURL("vault:///secret/"), nil) + assert.Nil(t, WithConfig(config, vaultFs)) + }) + + t.Run("nil config ignored", func(t *testing.T) { + vaultFs := newWithVaultClient(tests.MustURL("vault:///secret/"), nil) + fsys := WithConfig(nil, vaultFs) + assert.Same(t, vaultFs, fsys) + }) + + t.Run("URL with host overrides what's in the config", func(t *testing.T) { + config := api.DefaultConfig() + testdata := []struct { + url, addr string + }{ + {"vault+https://example.com/secret/", "https://example.com"}, + {"vault://example.com/secret/foo", "https://example.com"}, + {"vault+http://example.com/secret/", "http://example.com"}, + } + + for _, d := range testdata { + vaultFs := newWithVaultClient(tests.MustURL(d.url), nil) + vaultFs = WithConfig(config, vaultFs).(*vaultFS) + assert.Equal(t, d.addr, vaultFs.client.CloneConfig().Address) + } + }) +} + +func TestWithClient(t *testing.T) { + cl := fakevault.Server(t) + + fsys := WithAuthMethod( + TokenAuthMethod("blargh"), + // fsys without vault client - will panic unless a client is injected + newWithVaultClient(tests.MustURL("vault:///secret/"), nil), + ) + fsys = WithClient(cl, fsys).(*vaultFS) + + f, err := fsys.Open("foo") + require.NoError(t, err) + + fi, err := f.Stat() + require.NoError(t, err) + assert.Equal(t, "application/json", fsimpl.ContentType(fi)) +}