Skip to content

Commit

Permalink
Add ability to customize some timeouts in MongoDB database plugin (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
pcman312 committed May 17, 2021
1 parent 3c5754c commit 5ac64ea
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 65 deletions.
16 changes: 8 additions & 8 deletions builtin/logical/database/path_rotate_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ func (b *databaseBackend) pathRotateRootCredentialsUpdate() framework.OperationF
return nil, err
}

// Take out the backend lock since we are swapping out the connection
b.Lock()
defer b.Unlock()

// Take the write lock on the instance
dbi.Lock()
defer dbi.Unlock()

defer func() {
// Close the plugin
dbi.closed = true
Expand All @@ -88,14 +96,6 @@ func (b *databaseBackend) pathRotateRootCredentialsUpdate() framework.OperationF
delete(b.connections, name)
}()

// Take out the backend lock since we are swapping out the connection
b.Lock()
defer b.Unlock()

// Take the write lock on the instance
dbi.Lock()
defer dbi.Unlock()

// Generate new credentials
oldPassword := config.ConnectionDetails["password"].(string)
newPassword, err := dbi.database.GeneratePassword(ctx, b.System(), config.PasswordPolicy)
Expand Down
9 changes: 9 additions & 0 deletions changelog/11600.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
```release-note:improvement
secrets/database/mongodb: Add ability to customize `SocketTimeout`, `ConnectTimeout`, and `ServerSelectionTimeout`
```
```release-note:improvement
secrets/database/mongodb: Increased throughput by allowing for multiple request threads to simultaneously update users in MongoDB
```
```release-note:bug
secrets/database: Fixed minor race condition when rotate-root is called
```
106 changes: 94 additions & 12 deletions plugins/database/mongodb/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/mitchellh/mapstructure"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
Expand All @@ -31,6 +32,10 @@ type mongoDBConnectionProducer struct {
TLSCertificateKeyData []byte `json:"tls_certificate_key" structs:"-" mapstructure:"tls_certificate_key"`
TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"`

SocketTimeout time.Duration `json:"socket_timeout" structs:"-" mapstructure:"socket_timeout"`
ConnectTimeout time.Duration `json:"connect_timeout" structs:"-" mapstructure:"connect_timeout"`
ServerSelectionTimeout time.Duration `json:"server_selection_timeout" structs:"-" mapstructure:"server_selection_timeout"`

Initialized bool
RawConfig map[string]interface{}
Type string
Expand All @@ -48,15 +53,47 @@ type writeConcern struct {
J bool // Sync via the journal if present
}

func (c *mongoDBConnectionProducer) loadConfig(cfg map[string]interface{}) error {
err := mapstructure.WeakDecode(cfg, c)
if err != nil {
return err
}

if len(c.ConnectionURL) == 0 {
return fmt.Errorf("connection_url cannot be empty")
}

if c.SocketTimeout < 0 {
return fmt.Errorf("socket_timeout must be >= 0")
}
if c.ConnectTimeout < 0 {
return fmt.Errorf("connect_timeout must be >= 0")
}
if c.ServerSelectionTimeout < 0 {
return fmt.Errorf("server_selection_timeout must be >= 0")
}

opts, err := c.makeClientOpts()
if err != nil {
return err
}

c.clientOptions = opts

return nil
}

// Connection creates or returns an existing a database connection. If the session fails
// on a ping check, the session will be closed and then re-created.
// This method does not lock the mutex and it is intended that this is the callers
// responsibility.
func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
// This method does locks the mutex on its own.
func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (*mongo.Client, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}

c.Mutex.Lock()
defer c.Mutex.Unlock()

if c.client != nil {
if err := c.client.Ping(ctx, readpref.Primary()); err == nil {
return c.client, nil
Expand All @@ -65,23 +102,22 @@ func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{}
_ = c.client.Disconnect(ctx)
}

connURL := c.getConnectionURL()
client, err := createClient(ctx, connURL, c.clientOptions)
client, err := c.createClient(ctx)
if err != nil {
return nil, err
}
c.client = client
return c.client, nil
}

func createClient(ctx context.Context, connURL string, clientOptions *options.ClientOptions) (client *mongo.Client, err error) {
if clientOptions == nil {
clientOptions = options.Client()
func (c *mongoDBConnectionProducer) createClient(ctx context.Context) (client *mongo.Client, err error) {
if !c.Initialized {
return nil, fmt.Errorf("failed to create client: connection producer is not initialized")
}
clientOptions.SetSocketTimeout(1 * time.Minute)
clientOptions.SetConnectTimeout(1 * time.Minute)

client, err = mongo.Connect(ctx, options.MergeClientOptions(options.Client().ApplyURI(connURL), clientOptions))
if c.clientOptions == nil {
return nil, fmt.Errorf("missing client options")
}
client, err = mongo.Connect(ctx, options.MergeClientOptions(options.Client().ApplyURI(c.getConnectionURL()), c.clientOptions))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -120,6 +156,26 @@ func (c *mongoDBConnectionProducer) getConnectionURL() (connURL string) {
return connURL
}

func (c *mongoDBConnectionProducer) makeClientOpts() (*options.ClientOptions, error) {
writeOpts, err := c.getWriteConcern()
if err != nil {
return nil, err
}

authOpts, err := c.getTLSAuth()
if err != nil {
return nil, err
}

timeoutOpts, err := c.timeoutOpts()
if err != nil {
return nil, err
}

opts := options.MergeClientOptions(writeOpts, authOpts, timeoutOpts)
return opts, nil
}

func (c *mongoDBConnectionProducer) getWriteConcern() (opts *options.ClientOptions, err error) {
if c.WriteConcern == "" {
return nil, nil
Expand Down Expand Up @@ -206,3 +262,29 @@ func (c *mongoDBConnectionProducer) getTLSAuth() (opts *options.ClientOptions, e
opts.SetTLSConfig(tlsConfig)
return opts, nil
}

func (c *mongoDBConnectionProducer) timeoutOpts() (opts *options.ClientOptions, err error) {
opts = options.Client()

if c.SocketTimeout < 0 {
return nil, fmt.Errorf("socket_timeout must be >= 0")
}

if c.SocketTimeout == 0 {
opts.SetSocketTimeout(1 * time.Minute)
} else {
opts.SetSocketTimeout(c.SocketTimeout)
}

if c.ConnectTimeout == 0 {
opts.SetConnectTimeout(1 * time.Minute)
} else {
opts.SetConnectTimeout(c.ConnectTimeout)
}

if c.ServerSelectionTimeout != 0 {
opts.SetServerSelectionTimeout(c.ServerSelectionTimeout)
}

return opts, nil
}
2 changes: 1 addition & 1 deletion plugins/database/mongodb/connection_producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ net:
"connectionStatus": 1,
}

client, err := mongo.getConnection(ctx)
client, err := mongo.Connection(ctx)
if err != nil {
t.Fatalf("Unable to make connection to Mongo: %s", err)
}
Expand Down
59 changes: 15 additions & 44 deletions plugins/database/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ import (
"io"
"strings"

log "github.com/hashicorp/go-hclog"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/helper/template"

dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/mitchellh/mapstructure"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
Expand Down Expand Up @@ -57,15 +55,6 @@ func (m *MongoDB) Type() (string, error) {
return mongoDBTypeName, nil
}

func (m *MongoDB) getConnection(ctx context.Context) (*mongo.Client, error) {
client, err := m.Connection(ctx)
if err != nil {
return nil, err
}

return client.(*mongo.Client), nil
}

func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
m.Lock()
defer m.Unlock()
Expand All @@ -91,41 +80,27 @@ func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest
return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err)
}

err = mapstructure.WeakDecode(req.Config, m.mongoDBConnectionProducer)
if err != nil {
return dbplugin.InitializeResponse{}, err
}

if len(m.ConnectionURL) == 0 {
return dbplugin.InitializeResponse{}, fmt.Errorf("connection_url cannot be empty-mongo fail")
}

writeOpts, err := m.getWriteConcern()
if err != nil {
return dbplugin.InitializeResponse{}, err
}

authOpts, err := m.getTLSAuth()
err = m.mongoDBConnectionProducer.loadConfig(req.Config)
if err != nil {
return dbplugin.InitializeResponse{}, err
}

m.clientOptions = options.MergeClientOptions(writeOpts, authOpts)

// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
m.Initialized = true

if req.VerifyConnection {
_, err := m.Connection(ctx)
client, err := m.mongoDBConnectionProducer.createClient(ctx)
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to verify connection: %w", err)
}

err = m.client.Ping(ctx, readpref.Primary())
err = client.Ping(ctx, readpref.Primary())
if err != nil {
_ = client.Disconnect(ctx) // Try to prevent any sort of resource leak
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to verify connection: %w", err)
}
m.mongoDBConnectionProducer.client = client
}

resp := dbplugin.InitializeResponse{
Expand All @@ -135,10 +110,6 @@ func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest
}

func (m *MongoDB) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
// Grab the lock
m.Lock()
defer m.Unlock()

if len(req.Statements.Commands) == 0 {
return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
}
Expand Down Expand Up @@ -189,9 +160,6 @@ func (m *MongoDB) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest
}

func (m *MongoDB) changeUserPassword(ctx context.Context, username, password string) error {
m.Lock()
defer m.Unlock()

connURL := m.getConnectionURL()
cs, err := connstring.Parse(connURL)
if err != nil {
Expand All @@ -218,9 +186,6 @@ func (m *MongoDB) changeUserPassword(ctx context.Context, username, password str
}

func (m *MongoDB) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
m.Lock()
defer m.Unlock()

// If no revocation statements provided, pass in empty JSON
var revocationStatement string
switch len(req.Statements.Commands) {
Expand Down Expand Up @@ -251,14 +216,20 @@ func (m *MongoDB) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest
}

err = m.runCommandWithRetry(ctx, db, dropUserCmd)
cErr, ok := err.(mongo.CommandError)
if ok && cErr.Name == "UserNotFound" { // User already removed, don't retry needlessly
log.Default().Warn("MongoDB user was deleted prior to lease revocation", "user", req.Username)
return dbplugin.DeleteUserResponse{}, nil
}

return dbplugin.DeleteUserResponse{}, err
}

// runCommandWithRetry runs a command and retries once more if there's a failure
// on the first attempt. This should be called with the lock held
func (m *MongoDB) runCommandWithRetry(ctx context.Context, db string, cmd interface{}) error {
// Get the client
client, err := m.getConnection(ctx)
client, err := m.Connection(ctx)
if err != nil {
return err
}
Expand All @@ -273,7 +244,7 @@ func (m *MongoDB) runCommandWithRetry(ctx context.Context, db string, cmd interf
return nil
case err == io.EOF, strings.Contains(err.Error(), "EOF"):
// Call getConnection to reset and retry query if we get an EOF error on first attempt.
client, err = m.getConnection(ctx)
client, err = m.Connection(ctx)
if err != nil {
return err
}
Expand Down

0 comments on commit 5ac64ea

Please sign in to comment.