Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport 1.7.x: Add ability to customize some timeouts in MongoDB database plugin #11637

Merged
merged 1 commit into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions builtin/logical/database/path_rotate_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ func (b *databaseBackend) pathRotateRootCredentialsUpdate() framework.OperationF
return nil, err
}

// Take out the backend lock since we are swapping out the connection
b.Lock()
defer b.Unlock()

// Take the write lock on the instance
dbi.Lock()
defer dbi.Unlock()

defer func() {
// Close the plugin
dbi.closed = true
Expand All @@ -88,14 +96,6 @@ func (b *databaseBackend) pathRotateRootCredentialsUpdate() framework.OperationF
delete(b.connections, name)
}()

// Take out the backend lock since we are swapping out the connection
b.Lock()
defer b.Unlock()

// Take the write lock on the instance
dbi.Lock()
defer dbi.Unlock()

// Generate new credentials
oldPassword := config.ConnectionDetails["password"].(string)
newPassword, err := dbi.database.GeneratePassword(ctx, b.System(), config.PasswordPolicy)
Expand Down
9 changes: 9 additions & 0 deletions changelog/11600.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
```release-note:improvement
secrets/database/mongodb: Add ability to customize `SocketTimeout`, `ConnectTimeout`, and `ServerSelectionTimeout`
```
```release-note:improvement
secrets/database/mongodb: Increased throughput by allowing for multiple request threads to simultaneously update users in MongoDB
```
```release-note:bug
secrets/database: Fixed minor race condition when rotate-root is called
```
106 changes: 94 additions & 12 deletions plugins/database/mongodb/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/mitchellh/mapstructure"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
Expand All @@ -31,6 +32,10 @@ type mongoDBConnectionProducer struct {
TLSCertificateKeyData []byte `json:"tls_certificate_key" structs:"-" mapstructure:"tls_certificate_key"`
TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"`

SocketTimeout time.Duration `json:"socket_timeout" structs:"-" mapstructure:"socket_timeout"`
ConnectTimeout time.Duration `json:"connect_timeout" structs:"-" mapstructure:"connect_timeout"`
ServerSelectionTimeout time.Duration `json:"server_selection_timeout" structs:"-" mapstructure:"server_selection_timeout"`

Initialized bool
RawConfig map[string]interface{}
Type string
Expand All @@ -48,15 +53,47 @@ type writeConcern struct {
J bool // Sync via the journal if present
}

func (c *mongoDBConnectionProducer) loadConfig(cfg map[string]interface{}) error {
err := mapstructure.WeakDecode(cfg, c)
if err != nil {
return err
}

if len(c.ConnectionURL) == 0 {
return fmt.Errorf("connection_url cannot be empty")
}

if c.SocketTimeout < 0 {
return fmt.Errorf("socket_timeout must be >= 0")
}
if c.ConnectTimeout < 0 {
return fmt.Errorf("connect_timeout must be >= 0")
}
if c.ServerSelectionTimeout < 0 {
return fmt.Errorf("server_selection_timeout must be >= 0")
}

opts, err := c.makeClientOpts()
if err != nil {
return err
}

c.clientOptions = opts

return nil
}

// Connection creates or returns an existing a database connection. If the session fails
// on a ping check, the session will be closed and then re-created.
// This method does not lock the mutex and it is intended that this is the callers
// responsibility.
func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
// This method does locks the mutex on its own.
func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (*mongo.Client, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}

c.Mutex.Lock()
defer c.Mutex.Unlock()

if c.client != nil {
if err := c.client.Ping(ctx, readpref.Primary()); err == nil {
return c.client, nil
Expand All @@ -65,23 +102,22 @@ func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{}
_ = c.client.Disconnect(ctx)
}

connURL := c.getConnectionURL()
client, err := createClient(ctx, connURL, c.clientOptions)
client, err := c.createClient(ctx)
if err != nil {
return nil, err
}
c.client = client
return c.client, nil
}

func createClient(ctx context.Context, connURL string, clientOptions *options.ClientOptions) (client *mongo.Client, err error) {
if clientOptions == nil {
clientOptions = options.Client()
func (c *mongoDBConnectionProducer) createClient(ctx context.Context) (client *mongo.Client, err error) {
if !c.Initialized {
return nil, fmt.Errorf("failed to create client: connection producer is not initialized")
}
clientOptions.SetSocketTimeout(1 * time.Minute)
clientOptions.SetConnectTimeout(1 * time.Minute)

client, err = mongo.Connect(ctx, options.MergeClientOptions(options.Client().ApplyURI(connURL), clientOptions))
if c.clientOptions == nil {
return nil, fmt.Errorf("missing client options")
}
client, err = mongo.Connect(ctx, options.MergeClientOptions(options.Client().ApplyURI(c.getConnectionURL()), c.clientOptions))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -120,6 +156,26 @@ func (c *mongoDBConnectionProducer) getConnectionURL() (connURL string) {
return connURL
}

func (c *mongoDBConnectionProducer) makeClientOpts() (*options.ClientOptions, error) {
writeOpts, err := c.getWriteConcern()
if err != nil {
return nil, err
}

authOpts, err := c.getTLSAuth()
if err != nil {
return nil, err
}

timeoutOpts, err := c.timeoutOpts()
if err != nil {
return nil, err
}

opts := options.MergeClientOptions(writeOpts, authOpts, timeoutOpts)
return opts, nil
}

func (c *mongoDBConnectionProducer) getWriteConcern() (opts *options.ClientOptions, err error) {
if c.WriteConcern == "" {
return nil, nil
Expand Down Expand Up @@ -206,3 +262,29 @@ func (c *mongoDBConnectionProducer) getTLSAuth() (opts *options.ClientOptions, e
opts.SetTLSConfig(tlsConfig)
return opts, nil
}

func (c *mongoDBConnectionProducer) timeoutOpts() (opts *options.ClientOptions, err error) {
opts = options.Client()

if c.SocketTimeout < 0 {
return nil, fmt.Errorf("socket_timeout must be >= 0")
}

if c.SocketTimeout == 0 {
opts.SetSocketTimeout(1 * time.Minute)
} else {
opts.SetSocketTimeout(c.SocketTimeout)
}

if c.ConnectTimeout == 0 {
opts.SetConnectTimeout(1 * time.Minute)
} else {
opts.SetConnectTimeout(c.ConnectTimeout)
}

if c.ServerSelectionTimeout != 0 {
opts.SetServerSelectionTimeout(c.ServerSelectionTimeout)
}

return opts, nil
}
2 changes: 1 addition & 1 deletion plugins/database/mongodb/connection_producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ net:
"connectionStatus": 1,
}

client, err := mongo.getConnection(ctx)
client, err := mongo.Connection(ctx)
if err != nil {
t.Fatalf("Unable to make connection to Mongo: %s", err)
}
Expand Down
59 changes: 15 additions & 44 deletions plugins/database/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ import (
"io"
"strings"

log "github.com/hashicorp/go-hclog"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/helper/template"

dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/mitchellh/mapstructure"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
Expand Down Expand Up @@ -57,15 +55,6 @@ func (m *MongoDB) Type() (string, error) {
return mongoDBTypeName, nil
}

func (m *MongoDB) getConnection(ctx context.Context) (*mongo.Client, error) {
client, err := m.Connection(ctx)
if err != nil {
return nil, err
}

return client.(*mongo.Client), nil
}

func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
m.Lock()
defer m.Unlock()
Expand All @@ -91,41 +80,27 @@ func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest
return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template: %w", err)
}

err = mapstructure.WeakDecode(req.Config, m.mongoDBConnectionProducer)
if err != nil {
return dbplugin.InitializeResponse{}, err
}

if len(m.ConnectionURL) == 0 {
return dbplugin.InitializeResponse{}, fmt.Errorf("connection_url cannot be empty-mongo fail")
}

writeOpts, err := m.getWriteConcern()
if err != nil {
return dbplugin.InitializeResponse{}, err
}

authOpts, err := m.getTLSAuth()
err = m.mongoDBConnectionProducer.loadConfig(req.Config)
if err != nil {
return dbplugin.InitializeResponse{}, err
}

m.clientOptions = options.MergeClientOptions(writeOpts, authOpts)

// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
m.Initialized = true

if req.VerifyConnection {
_, err := m.Connection(ctx)
client, err := m.mongoDBConnectionProducer.createClient(ctx)
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to verify connection: %w", err)
}

err = m.client.Ping(ctx, readpref.Primary())
err = client.Ping(ctx, readpref.Primary())
if err != nil {
_ = client.Disconnect(ctx) // Try to prevent any sort of resource leak
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to verify connection: %w", err)
}
m.mongoDBConnectionProducer.client = client
}

resp := dbplugin.InitializeResponse{
Expand All @@ -135,10 +110,6 @@ func (m *MongoDB) Initialize(ctx context.Context, req dbplugin.InitializeRequest
}

func (m *MongoDB) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
// Grab the lock
m.Lock()
defer m.Unlock()

if len(req.Statements.Commands) == 0 {
return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
}
Expand Down Expand Up @@ -189,9 +160,6 @@ func (m *MongoDB) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest
}

func (m *MongoDB) changeUserPassword(ctx context.Context, username, password string) error {
m.Lock()
defer m.Unlock()

connURL := m.getConnectionURL()
cs, err := connstring.Parse(connURL)
if err != nil {
Expand All @@ -218,9 +186,6 @@ func (m *MongoDB) changeUserPassword(ctx context.Context, username, password str
}

func (m *MongoDB) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
m.Lock()
defer m.Unlock()

// If no revocation statements provided, pass in empty JSON
var revocationStatement string
switch len(req.Statements.Commands) {
Expand Down Expand Up @@ -251,14 +216,20 @@ func (m *MongoDB) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest
}

err = m.runCommandWithRetry(ctx, db, dropUserCmd)
cErr, ok := err.(mongo.CommandError)
if ok && cErr.Name == "UserNotFound" { // User already removed, don't retry needlessly
log.Default().Warn("MongoDB user was deleted prior to lease revocation", "user", req.Username)
return dbplugin.DeleteUserResponse{}, nil
}

return dbplugin.DeleteUserResponse{}, err
}

// runCommandWithRetry runs a command and retries once more if there's a failure
// on the first attempt. This should be called with the lock held
func (m *MongoDB) runCommandWithRetry(ctx context.Context, db string, cmd interface{}) error {
// Get the client
client, err := m.getConnection(ctx)
client, err := m.Connection(ctx)
if err != nil {
return err
}
Expand All @@ -273,7 +244,7 @@ func (m *MongoDB) runCommandWithRetry(ctx context.Context, db string, cmd interf
return nil
case err == io.EOF, strings.Contains(err.Error(), "EOF"):
// Call getConnection to reset and retry query if we get an EOF error on first attempt.
client, err = m.getConnection(ctx)
client, err = m.Connection(ctx)
if err != nil {
return err
}
Expand Down