Skip to content

Commit

Permalink
Simplify vault auth (#17)
Browse files Browse the repository at this point in the history
* fix token auth
* simplify code of vault auth
  * refactor evaluation of vault secret to client
  * remove unnecessary interfaces and generics
  • Loading branch information
Argelbargel authored Oct 5, 2023
1 parent 988546e commit e2deaf7
Show file tree
Hide file tree
Showing 23 changed files with 344 additions and 369 deletions.
4 changes: 2 additions & 2 deletions internal/agent/snapshot-agent-config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestReadCompleteConfig(t *testing.T) {
configFile := "../../testdata/complete.yaml"

expectedConfig := SnapshotAgentConfig{
Vault: vault.ClientConfig{
Vault: vault.VaultClientConfig{
Url: "https://example.com:8200",
Insecure: true,
Timeout: 5 * time.Minute,
Expand Down Expand Up @@ -144,7 +144,7 @@ func TestReadConfigSetsDefaultValues(t *testing.T) {
configFile := "../../testdata/snapshots.yaml"

expectedConfig := SnapshotAgentConfig{
Vault: vault.ClientConfig{
Vault: vault.VaultClientConfig{
Url: "http://127.0.0.1:8200",
Insecure: false,
Timeout: time.Minute,
Expand Down
2 changes: 1 addition & 1 deletion internal/agent/snapshot-agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

// SnapshotAgentConfig is the root of the agent-configuration
type SnapshotAgentConfig struct {
Vault vault.ClientConfig
Vault vault.VaultClientConfig
Snapshots SnapshotsConfig
}

Expand Down
14 changes: 8 additions & 6 deletions internal/agent/snapshot-agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"github.com/Argelbargel/vault-raft-snapshot-agent/internal/agent/storage"
"github.com/Argelbargel/vault-raft-snapshot-agent/internal/agent/vault"
"github.com/hashicorp/vault/api"
"io"
"testing"
"time"
Expand Down Expand Up @@ -257,8 +258,8 @@ func TestUpdateReschedulesSnapshots(t *testing.T) {
assert.Equal(t, newManager, agent.manager)
}

func newClient(api *clientVaultAPIStub) *vault.Client[any, clientVaultAPIAuthStub] {
return vault.NewClient[any, clientVaultAPIAuthStub](api, clientVaultAPIAuthStub{}, time.Time{})
func newClient(api *clientVaultAPIStub) *vault.VaultClient {
return vault.NewClient(api, clientVaultAPIAuthStub{}, time.Time{})
}

type clientVaultAPIStub struct {
Expand Down Expand Up @@ -297,14 +298,15 @@ func (stub *clientVaultAPIStub) IsLeader() (bool, error) {
return stub.leader, nil
}

func (stub *clientVaultAPIStub) RefreshAuth(ctx context.Context, auth clientVaultAPIAuthStub) (time.Duration, error) {
return auth.Login(ctx, nil)
func (stub *clientVaultAPIStub) RefreshAuth(ctx context.Context, auth api.AuthMethod) (time.Duration, error) {
_, err := auth.Login(ctx, nil)
return 0, err
}

type clientVaultAPIAuthStub struct{}

func (stub clientVaultAPIAuthStub) Login(_ context.Context, _ any) (time.Duration, error) {
return 0, nil
func (stub clientVaultAPIAuthStub) Login(_ context.Context, _ *api.Client) (*api.Secret, error) {
return nil, nil
}

type storageControllerFactoryStub struct {
Expand Down
34 changes: 15 additions & 19 deletions internal/agent/vault/auth/approle.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"github.com/Argelbargel/vault-raft-snapshot-agent/internal/agent/config/secret"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/api/auth/approle"
)

Expand All @@ -12,25 +13,20 @@ type AppRoleAuthConfig struct {
Empty bool
}

func createAppRoleAuth(config AppRoleAuthConfig) vaultAuthMethod[AppRoleAuthConfig, *approle.AppRoleAuth] {
return vaultAuthMethod[AppRoleAuthConfig, *approle.AppRoleAuth]{
config,
func(config AppRoleAuthConfig) (*approle.AppRoleAuth, error) {
roleId, err := config.RoleId.Resolve(true)
if err != nil {
return nil, err
}

secretId, err := config.SecretId.Resolve(true)
if err != nil {
return nil, err
}
func (c AppRoleAuthConfig) createAuthMethod() (api.AuthMethod, error) {
roleId, err := c.RoleId.Resolve(true)
if err != nil {
return nil, err
}

return approle.NewAppRoleAuth(
roleId,
&approle.SecretID{FromString: secretId},
approle.WithMountPath(config.Path),
)
},
secretId, err := c.SecretId.Resolve(true)
if err != nil {
return nil, err
}

return approle.NewAppRoleAuth(
roleId,
&approle.SecretID{FromString: secretId},
approle.WithMountPath(c.Path),
)
}
2 changes: 1 addition & 1 deletion internal/agent/vault/auth/approle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestCreateAppRoleAuth(t *testing.T) {
)
assert.NoError(t, err, "NewAppRoleAuth failed unexpectedly")

method, err := createAppRoleAuth(config).createAuthMethod()
method, err := config.createAuthMethod()
assert.NoError(t, err, "createAuthMethod failed unexpectedly")

assert.Equal(t, expectedAuthMethod, method)
Expand Down
61 changes: 19 additions & 42 deletions internal/agent/vault/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ package auth
import (
"context"
"fmt"
"github.com/Argelbargel/vault-raft-snapshot-agent/internal/agent/config/secret"
"github.com/Argelbargel/vault-raft-snapshot-agent/internal/agent/logging"
"time"

"github.com/hashicorp/vault/api"
)

Expand All @@ -18,65 +15,45 @@ type VaultAuthConfig struct {
Kubernetes KubernetesAuthConfig `default:"{\"Empty\": true}"`
LDAP LDAPAuthConfig `default:"{\"Empty\": true}"`
UserPass UserPassAuthConfig `default:"{\"Empty\": true}"`
Token secret.Secret
Token Token
}

type VaultAuth[C any] interface {
Login(ctx context.Context, client C) (time.Duration, error)
type vaultAuthMethodFactory interface {
createAuthMethod() (api.AuthMethod, error)
}

type vaultAuthMethod[C any, M api.AuthMethod] struct {
config C
methodFactory func(config C) (M, error)
type vaultAuthMethodImpl struct {
methodFactory vaultAuthMethodFactory
}

func CreateVaultAuth(config VaultAuthConfig) (VaultAuth[*api.Client], error) {
func CreateVaultAuth(config VaultAuthConfig) (api.AuthMethod, error) {
if !config.AppRole.Empty {
return createAppRoleAuth(config.AppRole), nil
return vaultAuthMethodImpl{config.AppRole}, nil
} else if !config.AWS.Empty {
return createAWSAuth(config.AWS), nil
return vaultAuthMethodImpl{config.AWS}, nil
} else if !config.Azure.Empty {
return createAzureAuth(config.Azure), nil
return vaultAuthMethodImpl{config.Azure}, nil
} else if !config.GCP.Empty {
return createGCPAuth(config.GCP), nil
return vaultAuthMethodImpl{config.GCP}, nil
} else if !config.Kubernetes.Empty {
return createKubernetesAuth(config.Kubernetes), nil
return vaultAuthMethodImpl{config.Kubernetes}, nil
} else if !config.LDAP.Empty {
return createLDAPAuth(config.LDAP), nil
return vaultAuthMethodImpl{config.LDAP}, nil
} else if !config.UserPass.Empty {
return createUserPassAuth(config.UserPass), nil
return vaultAuthMethodImpl{config.UserPass}, nil
} else if config.Token != "" {
return createTokenAuth(config.Token), nil
return vaultAuthMethodImpl{config.Token}, nil
} else {
return nil, fmt.Errorf("unknown authenticatin method")
}
}

func (am vaultAuthMethod[C, M]) Login(ctx context.Context, client *api.Client) (time.Duration, error) {
method, err := am.methodFactory(am.config)
if err != nil {
return 0, err
}

authSecret, err := client.Auth().Login(ctx, method)
if err != nil {
return 0, err
}

tokenTTL, err := authSecret.TokenTTL()
if err != nil {
return 0, err
}

tokenPolicies, err := authSecret.TokenPolicies()
func (am vaultAuthMethodImpl) Login(ctx context.Context, client *api.Client) (*api.Secret, error) {
method, err := am.methodFactory.createAuthMethod()
if err != nil {
return 0, err
return nil, err
}

logging.Debug("Successfully logged into vault", "ttl", tokenTTL, "policies", tokenPolicies)
return tokenTTL, nil
}

func (am vaultAuthMethod[C, M]) createAuthMethod() (M, error) {
return am.methodFactory(am.config)
logging.Debug("Logging into vault", "method", fmt.Sprintf("%T", method))
return client.Auth().Login(ctx, method)
}
55 changes: 32 additions & 23 deletions internal/agent/vault/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@ import (
"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func TestVaultAuthMethod_Login_FailsIfMethodFactoryFails(t *testing.T) {
expectedErr := errors.New("methodFactory failed")
auth := vaultAuthMethod[any, api.AuthMethod]{
methodFactory: func(_ any) (api.AuthMethod, error) {
return nil, expectedErr
},
expectedErr := errors.New("create failed")
auth := vaultAuthMethodImpl{
authMethodFactoryStub{createErr: expectedErr},
}

_, err := auth.Login(context.Background(), nil)
Expand All @@ -23,9 +20,9 @@ func TestVaultAuthMethod_Login_FailsIfMethodFactoryFails(t *testing.T) {

func TestVaultAuthMethod_Login_FailsIfAuthMethodLoginFails(t *testing.T) {
expectedErr := errors.New("login failed")
auth := vaultAuthMethod[any, api.AuthMethod]{
methodFactory: func(_ any) (api.AuthMethod, error) {
return authMethodStub{loginError: expectedErr}, nil
auth := vaultAuthMethodImpl{
authMethodFactoryStub{
method: authMethodStub{loginError: expectedErr},
},
}

Expand All @@ -34,33 +31,45 @@ func TestVaultAuthMethod_Login_FailsIfAuthMethodLoginFails(t *testing.T) {
}

func TestVaultAuthMethod_Login_ReturnsLeaseDuration(t *testing.T) {
expectedLeaseDuration := 60
auth := vaultAuthMethod[any, api.AuthMethod]{
methodFactory: func(_ any) (api.AuthMethod, error) {
return authMethodStub{leaseDuration: expectedLeaseDuration}, nil
expectedSecret := &api.Secret{
Auth: &api.SecretAuth{
ClientToken: "test",
},
}

auth := vaultAuthMethodImpl{
authMethodFactoryStub{
method: authMethodStub{secret: expectedSecret},
},
}

leaseDuration, err := auth.Login(context.Background(), &api.Client{})
authSecret, err := auth.Login(context.Background(), &api.Client{})

assert.NoError(t, err, "Login failed unexpectedly")
assert.Equal(t, time.Duration(expectedLeaseDuration)*time.Second, leaseDuration)
assert.Equal(t, expectedSecret, authSecret)
}

type authMethodFactoryStub struct {
method api.AuthMethod
createErr error
}

func (stub authMethodFactoryStub) createAuthMethod() (api.AuthMethod, error) {
if stub.createErr != nil {
return nil, stub.createErr
}
return stub.method, nil
}

type authMethodStub struct {
loginError error
leaseDuration int
loginError error
secret *api.Secret
}

func (stub authMethodStub) Login(_ context.Context, _ *api.Client) (*api.Secret, error) {
if stub.loginError != nil {
return nil, stub.loginError
}

return &api.Secret{
Auth: &api.SecretAuth{
ClientToken: "Test",
LeaseDuration: stub.leaseDuration,
},
}, nil
return stub.secret, nil
}
Loading

0 comments on commit e2deaf7

Please sign in to comment.