Skip to content

Commit

Permalink
fix(vaultfs): Support KVv2 mounts that contain slashes (#738)
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Henderson <[email protected]>
  • Loading branch information
hairyhenderson authored Jun 30, 2024
1 parent 6fb5d96 commit 87b1367
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 326 deletions.
128 changes: 128 additions & 0 deletions internal/tests/fakevault/fakevault.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package fakevault

import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/assert"
)

func mountHandler(w http.ResponseWriter, _ *http.Request) {
mounts := map[string]interface{}{
"secret/": map[string]interface{}{
"type": "kv",
},
}

resp := map[string]interface{}{
"data": map[string]interface{}{
"secret": mounts,
},
}

enc := json.NewEncoder(w)
_ = enc.Encode(resp)
}

//nolint:gocyclo
func vaultHandler(t *testing.T, files map[string]fakeSecret) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "LIST" || (r.Method == http.MethodGet && r.URL.Query().Get("list") == "true") {
r.URL.Path += "/"

// transform back to list for simplicity
r.Method = "LIST"
vals := r.URL.Query()
vals.Del("list")
r.URL.RawQuery = vals.Encode()
}

data, ok := files[r.URL.Path]
if !ok {
w.WriteHeader(http.StatusNotFound)

return
}

q := r.URL.Query()
for k, v := range q {
if k == "method" {
assert.Equal(t, v[0], r.Method)
}
}

body := map[string]interface{}{}

if r.Body != nil {
dec := json.NewDecoder(r.Body)
_ = dec.Decode(&body)

defer r.Body.Close()

if p, ok := body["param"]; ok {
data.Param = p.(string)
}
}

switch r.Method {
case http.MethodGet:
assert.Empty(t, data.Param, r.URL)
assert.NotEmpty(t, data.Value, r.URL)
case http.MethodPost:
assert.NotEmpty(t, data.Param, r.URL)
case "LIST":
assert.NotEmpty(t, data.Keys, r.URL)
}

t.Logf("encoding %#v", data)

enc := json.NewEncoder(w)
_ = enc.Encode(map[string]interface{}{"data": data})
})
}

type fakeSecret struct {
Value string `json:"value,omitempty"`
Param string `json:"param,omitempty"`
Keys []string `json:"keys,omitempty"`
}

func Server(t *testing.T) *api.Client {
files := map[string]fakeSecret{
"/v1/secret/": {Keys: []string{"foo", "bar", "foo/"}},
"/v1/secret/foo": {Value: "foo"},
"/v1/secret/bar": {Value: "foo"},
"/v1/secret/foo/": {Keys: []string{"foo", "bar", "bazDir/"}},
"/v1/secret/foo/foo": {Value: "foo"},
"/v1/secret/foo/bar": {Value: "foo"},
"/v1/secret/foo/bazDir/": {Keys: []string{"foo", "bar", "bazDir/"}},
}

mux := http.NewServeMux()

mux.HandleFunc("/v1/sys/internal/ui/mounts", mountHandler)
mux.Handle("/", vaultHandler(t, files))

return FakeVault(t, mux)
}

func FakeVault(t *testing.T, handler http.Handler) *api.Client {
srv := httptest.NewServer(handler)
t.Cleanup(srv.Close)

tr := &http.Transport{
Proxy: func(_ *http.Request) (*url.URL, error) {
return url.Parse(srv.URL)
},
}
httpClient := &http.Client{Transport: tr}
config := &api.Config{Address: srv.URL, HttpClient: httpClient}

c, _ := api.NewClient(config)

return c
}
57 changes: 57 additions & 0 deletions internal/tests/integration/vaultfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ func setupVaultFSTest(ctx context.Context, t *testing.T) string {
}`)
require.NoError(t, err)

err = client.Sys().PutPolicyWithContext(ctx, "kv2pol",
`path "kv2/*" {
capabilities = ["read"]
}
path "a/b/c/*" {
capabilities = ["read", "list"]
}`)
require.NoError(t, err)

return addr
}

Expand Down Expand Up @@ -587,6 +596,7 @@ func TestVaultFS_List(t *testing.T) {
assert.Equal(t, "foo", de[1].Name())
}

//nolint:funlen
func TestVaultFS_KVv2(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down Expand Up @@ -647,4 +657,51 @@ func TestVaultFS_KVv2(t *testing.T) {

// v1 should have an earlier mod time than v2
assert.NotEqual(t, v2Time, fi.ModTime())

t.Run("mount with slashes", func(t *testing.T) {
mount := "a/b/c"
err = client.Sys().MountWithContext(ctx, mount, &api.MountInput{
Type: "kv",
Options: map[string]string{"version": "2"},
})
require.NoError(t, err)

s, err = client.KVv2(mount).Put(ctx, "d/e/f", map[string]interface{}{"e": "f"}, api.WithCheckAndSet(0))
require.NoError(t, err)

tok, err := tokenCreate(ctx, client, "kv2pol", 5)
require.NoError(t, err)

readClient, err := api.NewClient(&api.Config{Address: "http://" + addr})
require.NoError(t, err)

readClient.SetToken(tok)

fsys, err := vaultfs.New(tests.MustURL("http://" + addr))
require.NoError(t, err)

fsys = vaultauth.WithAuthMethod(vaultauth.NewTokenAuth(tok), fsys)
fsys = fsimpl.WithContextFS(ctx, fsys)

t.Run("can read", func(t *testing.T) {
f, err := fsys.Open(mount + "/d/e/f")
require.NoError(t, err)

b, err = io.ReadAll(f)
require.NoError(t, err)
assert.Equal(t, `{"e":"f"}`, string(b))
})

t.Run("can list", func(t *testing.T) {
des, err := fs.ReadDir(fsys, mount+"/d")
require.NoError(t, err)
assert.Len(t, des, 1)
assert.Equal(t, "e", des[0].Name())

des, err = fs.ReadDir(fsys, mount+"/d/e")
require.NoError(t, err)
assert.Len(t, des, 1)
assert.Equal(t, "f", des[0].Name())
})
})
}
9 changes: 5 additions & 4 deletions vaultfs/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import (
"testing"
"testing/fstest"

"github.com/hairyhenderson/go-fsimpl/internal/tests/fakevault"
"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/assert"
)

func TestEnvAuthLogin(t *testing.T) {
v := fakeVaultServer(t)
v := fakevault.Server(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -85,7 +86,7 @@ func TestAppRoleAuthMethod(t *testing.T) {
mount := "approle"
token := "approletoken"

client := fakeVault(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
client := fakevault.FakeVault(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/auth/"+mount+"/login", r.URL.Path)

out := map[string]interface{}{
Expand Down Expand Up @@ -157,7 +158,7 @@ func TestUserPassAuthMethod(t *testing.T) {
_ = enc.Encode(out)
})

client := fakeVault(t, mux)
client := fakevault.FakeVault(t, mux)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -194,7 +195,7 @@ func TestGitHubAuthMethod(t *testing.T) {
token := "sometoken"
ghtoken := "abcd1234"

client := fakeVault(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
client := fakevault.FakeVault(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1/auth/"+mount+"/login", r.URL.Path)

out := map[string]interface{}{
Expand Down
Loading

0 comments on commit 87b1367

Please sign in to comment.