Skip to content

Commit

Permalink
fix(vaultfs): bug when Stat builds mount info from #738
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Henderson <[email protected]>
  • Loading branch information
hairyhenderson committed Jun 30, 2024
1 parent 87b1367 commit afa1257
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 59 deletions.
2 changes: 1 addition & 1 deletion autofs/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (c *autoFS) New(u *url.URL) (fs.FS, error) {
}

func initMux() fsimpl.FSMux {
return sync.OnceValue[fsimpl.FSMux](func() fsimpl.FSMux {
return sync.OnceValue(func() fsimpl.FSMux {
mux := fsimpl.NewMux()
mux.Add(awsimdsfs.FS)
mux.Add(awssmfs.FS)
Expand Down
144 changes: 86 additions & 58 deletions vaultfs/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,84 +234,103 @@ type mountInfo struct {

var _ fs.ReadDirFile = (*vaultFile)(nil)

func (f *vaultFile) newRequest(method string) (*api.Request, error) {
q := f.u.Query()
if len(q) > 0 && method == http.MethodGet {
method = http.MethodPost
func (f *vaultFile) request(method string) (*api.KVSecret, *api.Secret, error) {
mountInfo, err := f.getMountInfo(f.ctx)
if err != nil {
return nil, nil, fmt.Errorf("get mount info: %w", err)
}

req := f.client.NewRequest(method, f.u.Path)
if method == http.MethodGet {
req.Params = q
} else if len(q) > 0 {
data := map[string]interface{}{}
// it's a KVv2 Get operation with the right type, version, and especially if
// the secret path is set - otherwise it might need to be a list operation
if mountInfo.secretPath != "" && mountInfo.Type == "kv" && mountInfo.Options["version"] == "2" {
var kv *api.KVSecret

for k, vs := range q {
for _, v := range vs {
data[k] = v
}
}

err := req.SetJSONBody(data)
kv, err = f.kv2request(f.ctx, mountInfo.name, mountInfo.secretPath)
if err != nil {
return nil, err
return nil, nil, fmt.Errorf("failed to get KV v2 secret: %w", err)
}
}

return req, nil
}
return kv, nil, nil
}

func (f *vaultFile) request(method string) (kv *api.KVSecret, secret *api.Secret, err error) {
mountInfo, err := f.getMountInfo(f.ctx)
secret, err := f.rawRequest(method)
if err != nil {
return nil, nil, fmt.Errorf("get mount info: %w", err)
return nil, nil, err
}

if mountInfo.Type == "kv" && mountInfo.Options["version"] == "2" {
kv, err = f.kv2request(f.ctx, mountInfo.name, mountInfo.secretPath)
return nil, secret, nil
}

func (f *vaultFile) kv2request(ctx context.Context, mount, secret string) (kv *api.KVSecret, err error) {
kv2client := f.client.KVv2(mount)

version := 0
if ver := f.u.Query().Get("version"); ver != "" {
version, err = strconv.Atoi(ver)
if err != nil {
return nil, nil, fmt.Errorf("failed to get KV v2 secret: %w", err)
return nil, fmt.Errorf("invalid version %q requested: %w", ver, err)
}

return kv, nil, nil
}

return kv2client.GetVersion(ctx, secret, version)
}

// rawRequest makes a raw request to Vault by constructing a new request from
// the method and URL, and returns the parsed secret.
//
// This should probably be replaced with a call to the logical client (either
// Read, Write, or List) in the future, especially as the RawRequestWithContext
// method is deprecated.
func (f *vaultFile) rawRequest(method string) (*api.Secret, error) {
req, err := f.newRequest(method)
if err != nil {
return nil, nil, fmt.Errorf("failed to create vault request: %w", err)
return nil, fmt.Errorf("failed to create vault request: %w", err)
}

//nolint:staticcheck
resp, err := f.client.RawRequestWithContext(f.ctx, req)
if err != nil {
return nil, nil, fmt.Errorf("http %s %s failed with: %w", method, f.u.Path,
return nil, fmt.Errorf("http %s %s failed with: %w", method, f.u.Path,
vaultFSError(err))
}

if resp.StatusCode == 0 || resp.StatusCode >= 400 {
return nil, nil, fmt.Errorf("http %s %s failed with status %d", method, f.u, resp.StatusCode)
return nil, fmt.Errorf("http %s %s failed with status %d", method, f.u, resp.StatusCode)
}

secret, err = api.ParseSecret(resp.Body)
secret, err := api.ParseSecret(resp.Body)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse vault secret: %w", err)
return nil, fmt.Errorf("failed to parse vault secret: %w", err)
}

return nil, secret, nil
return secret, nil
}

func (f *vaultFile) kv2request(ctx context.Context, mount, secret string) (kv *api.KVSecret, err error) {
kv2client := f.client.KVv2(mount)
func (f *vaultFile) newRequest(method string) (*api.Request, error) {
values := f.u.Query()
if len(values) > 0 && method == http.MethodGet {
method = http.MethodPost
}

version := 0
if ver := f.u.Query().Get("version"); ver != "" {
version, err = strconv.Atoi(ver)
req := f.client.NewRequest(method, f.u.Path)
if method == http.MethodGet {
req.Params = values
} else if len(values) > 0 {
data := map[string]interface{}{}

for k, vs := range values {
for _, v := range vs {
data[k] = v
}
}

err := req.SetJSONBody(data)
if err != nil {
return nil, fmt.Errorf("invalid version %q requested: %w", ver, err)
return nil, err
}
}

return kv2client.GetVersion(ctx, secret, version)
return req, nil
}

// Close the file. Will error on second call. Decrements the ref count on first
Expand Down Expand Up @@ -560,8 +579,6 @@ func vaultFSError(err error) error {
// getMountInfo calls the undocumented sys/internal/ui/mounts endpoint to set
// the file's mount metadata. This is used in preference to the sys/mounts
// API because this one works read-only roles (!). The result is cached.
//
//nolint:gocyclo
func (f *vaultFile) getMountInfo(ctx context.Context) (*mountInfo, error) {
if f.mountInfo != nil {
return f.mountInfo, nil
Expand Down Expand Up @@ -591,11 +608,28 @@ func (f *vaultFile) getMountInfo(ctx context.Context) (*mountInfo, error) {
return nil, fmt.Errorf("unexpected mount info format: %#v", s.Data)
}

for k, v := range rawMounts {
if strings.HasPrefix(f.u.Path, "/v1/"+k) {
v, ok := v.(map[string]interface{})
mi, err := findMountInfo(f.u.Path, rawMounts)
if err != nil {
return nil, err
}

if mi == nil {
return nil, fmt.Errorf("mount not found for %q", f.u.Path)
}

f.mountInfo = mi

return f.mountInfo, nil
}

func findMountInfo(rawFilePath string, rawMounts map[string]interface{}) (*mountInfo, error) {
for mountName, mountOpts := range rawMounts {
mountPrefix := path.Join("/v1", mountName)

if strings.HasPrefix(rawFilePath, mountPrefix) {
v, ok := mountOpts.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected mount info format for %q: %#v", k, v)
return nil, fmt.Errorf("unexpected mount info format for %q: %#v", mountName, v)
}

mount := &api.MountOutput{Type: v["type"].(string)}
Expand All @@ -608,21 +642,15 @@ func (f *vaultFile) getMountInfo(ctx context.Context) (*mountInfo, error) {
}
}

f.mountInfo = &mountInfo{
name: k,
secretPath: strings.TrimPrefix(f.u.Path, "/v1/"+k),
MountOutput: mount,
}
// build secretPath - it's the part after the mount name, including the
// / prefix
spath := strings.TrimPrefix(rawFilePath, mountPrefix)

break
return &mountInfo{name: mountName, secretPath: spath, MountOutput: mount}, nil
}
}

if f.mountInfo == nil {
return nil, fmt.Errorf("mount not found for %q", f.u.Path)
}

return f.mountInfo, nil
return nil, nil
}

func createdTimeFromData(kvsec *api.KVSecret) time.Time {
Expand Down
58 changes: 58 additions & 0 deletions vaultfs/vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,61 @@ func TestFileAuthCaching(t *testing.T) {
require.NoError(t, err)
assert.Empty(t, v.Token())
}

func TestFindMountInfo(t *testing.T) {
testdata := []struct {
expected *mountInfo
mountOpts interface{}
mountName string
rawFilePath string
}{
{
// no match
rawFilePath: "/v1/secret/a/b/c", mountName: "potato/",
mountOpts: map[string]interface{}{
"type": "kv",
"options": map[string]interface{}{"version": "1"},
}, expected: nil,
},
{
rawFilePath: "/v1/secret/a/b/c", mountName: "secret/",
mountOpts: map[string]interface{}{
"type": "kv", "options": map[string]interface{}{"version": "1"},
},
expected: &mountInfo{
secretPath: "/a/b/c",
name: "secret/",
MountOutput: &api.MountOutput{
Type: "kv",
Options: map[string]string{"version": "1"},
},
},
},
{
// just the mount, e.g. for list
rawFilePath: "/v1/kv2", mountName: "kv2/",
mountOpts: map[string]interface{}{
"type": "kv", "options": map[string]interface{}{"version": "2"},
},
expected: &mountInfo{
secretPath: "",
name: "kv2/",
MountOutput: &api.MountOutput{
Type: "kv",
Options: map[string]string{"version": "2"},
},
},
},
}

for _, d := range testdata {
rawMounts := map[string]interface{}{
"bogus/": map[string]interface{}{},
d.mountName: d.mountOpts,
}

actual, err := findMountInfo(d.rawFilePath, rawMounts)
require.NoError(t, err)
assert.EqualValues(t, d.expected, actual)
}
}

0 comments on commit afa1257

Please sign in to comment.