Skip to content

Commit

Permalink
mdbx: race conditions in MdbxKV.Close (#8409) (#9244)
Browse files Browse the repository at this point in the history
In the previous code WaitGroup db.wg.Add(), Wait() and db.closed were
not treated in sync. In particular, it was theoretically possible to
first check closed, then set closed and Wait, and then call wg.Add()
while waiting (leading to WaitGroup panic).
In theory it was also possible that db.env.BeginTxn() is called on a
closed or nil db.env, because db.wg.Add() was called only after BeginTxn
(db.wg.Wait() could already return).

WaitGroup is replaced with a Cond variable.
Now it is not possible to increase the active transactions count on a
closed database. It is also not possible to call BeginTxn on a closed
database.
  • Loading branch information
battlmonstr authored Jan 17, 2024
1 parent 5e5d849 commit 1914b52
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 15 deletions.
78 changes: 65 additions & 13 deletions erigon-lib/kv/mdbx/kv_mdbx.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"github.com/c2h5oh/datasize"
"github.com/erigontech/mdbx-go/mdbx"
stack2 "github.com/go-stack/stack"
"github.com/ledgerwatch/erigon-lib/mmap"
"github.com/ledgerwatch/log/v3"
"golang.org/x/exp/maps"
"golang.org/x/sync/semaphore"
Expand All @@ -44,6 +43,7 @@ import (
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/iter"
"github.com/ledgerwatch/erigon-lib/kv/order"
"github.com/ledgerwatch/erigon-lib/mmap"
)

const NonExistingDBI kv.DBI = 999_999_999
Expand Down Expand Up @@ -385,15 +385,20 @@ func (opts MdbxOpts) Open(ctx context.Context) (kv.RwDB, error) {
targetSemCount := int64(runtime.GOMAXPROCS(-1) * 16)
opts.roTxsLimiter = semaphore.NewWeighted(targetSemCount) // 1 less than max to allow unlocking to happen
}

txsCountMutex := &sync.Mutex{}

db := &MdbxKV{
opts: opts,
env: env,
log: opts.log,
wg: &sync.WaitGroup{},
buckets: kv.TableCfg{},
txSize: dirtyPagesLimit * opts.pageSize,
roTxsLimiter: opts.roTxsLimiter,

txsCountMutex: txsCountMutex,
txsAllDoneOnCloseCond: sync.NewCond(txsCountMutex),

leakDetector: dbg.NewLeakDetector("db."+opts.label.String(), dbg.SlowTx()),
}

Expand Down Expand Up @@ -457,14 +462,17 @@ func (opts MdbxOpts) MustOpen() kv.RwDB {
type MdbxKV struct {
log log.Logger
env *mdbx.Env
wg *sync.WaitGroup
buckets kv.TableCfg
roTxsLimiter *semaphore.Weighted // does limit amount of concurrent Ro transactions - in most casess runtime.NumCPU() is good value for this channel capacity - this channel can be shared with other components (like Decompressor)
opts MdbxOpts
txSize uint64
closed atomic.Bool
path string

txsCount uint
txsCountMutex *sync.Mutex
txsAllDoneOnCloseCond *sync.Cond

leakDetector *dbg.LeakDetector
}

Expand Down Expand Up @@ -507,13 +515,53 @@ func (db *MdbxKV) openDBIs(buckets []string) error {
})
}

func (db *MdbxKV) trackTxBegin() bool {
db.txsCountMutex.Lock()
defer db.txsCountMutex.Unlock()

isOpen := !db.closed.Load()
if isOpen {
db.txsCount++
}
return isOpen
}

func (db *MdbxKV) hasTxsAllDoneAndClosed() bool {
return (db.txsCount == 0) && db.closed.Load()
}

func (db *MdbxKV) trackTxEnd() {
db.txsCountMutex.Lock()
defer db.txsCountMutex.Unlock()

if db.txsCount > 0 {
db.txsCount--
} else {
panic("MdbxKV: unmatched trackTxEnd")
}

if db.hasTxsAllDoneAndClosed() {
db.txsAllDoneOnCloseCond.Signal()
}
}

func (db *MdbxKV) waitTxsAllDoneOnClose() {
db.txsCountMutex.Lock()
defer db.txsCountMutex.Unlock()

for !db.hasTxsAllDoneAndClosed() {
db.txsAllDoneOnCloseCond.Wait()
}
}

// Close closes db
// All transactions must be closed before closing the database.
func (db *MdbxKV) Close() {
if ok := db.closed.CompareAndSwap(false, true); !ok {
return
}
db.wg.Wait()
db.waitTxsAllDoneOnClose()

db.env.Close()
db.env = nil

Expand All @@ -526,10 +574,6 @@ func (db *MdbxKV) Close() {
}

func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) {
if db.closed.Load() {
return nil, fmt.Errorf("db closed")
}

// don't try to acquire if the context is already done
select {
case <-ctx.Done():
Expand All @@ -538,8 +582,13 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) {
// otherwise carry on
}

if !db.trackTxBegin() {
return nil, fmt.Errorf("db closed")
}

// will return nil err if context is cancelled (may appear to acquire the semaphore)
if semErr := db.roTxsLimiter.Acquire(ctx, 1); semErr != nil {
db.trackTxEnd()
return nil, semErr
}

Expand All @@ -548,14 +597,15 @@ func (db *MdbxKV) BeginRo(ctx context.Context) (txn kv.Tx, err error) {
// on error, or if there is whatever reason that we don't return a tx,
// we need to free up the limiter slot, otherwise it could lead to deadlocks
db.roTxsLimiter.Release(1)
db.trackTxEnd()
}
}()

tx, err := db.env.BeginTxn(nil, mdbx.Readonly)
if err != nil {
return nil, fmt.Errorf("%w, label: %s, trace: %s", err, db.opts.label.String(), stack2.Trace().String())
}
db.wg.Add(1)

return &MdbxTx{
ctx: ctx,
db: db,
Expand All @@ -579,16 +629,18 @@ func (db *MdbxKV) beginRw(ctx context.Context, flags uint) (txn kv.RwTx, err err
default:
}

if db.closed.Load() {
if !db.trackTxBegin() {
return nil, fmt.Errorf("db closed")
}

runtime.LockOSThread()
tx, err := db.env.BeginTxn(nil, flags)
if err != nil {
runtime.UnlockOSThread() // unlock only in case of error. normal flow is "defer .Rollback()"
db.trackTxEnd()
return nil, fmt.Errorf("%w, lable: %s, trace: %s", err, db.opts.label.String(), stack2.Trace().String())
}
db.wg.Add(1)

return &MdbxTx{
db: db,
tx: tx,
Expand Down Expand Up @@ -830,7 +882,7 @@ func (tx *MdbxTx) Commit() error {
}
defer func() {
tx.tx = nil
tx.db.wg.Done()
tx.db.trackTxEnd()
if tx.readOnly {
tx.db.roTxsLimiter.Release(1)
} else {
Expand Down Expand Up @@ -881,7 +933,7 @@ func (tx *MdbxTx) Rollback() {
}
defer func() {
tx.tx = nil
tx.db.wg.Done()
tx.db.trackTxEnd()
if tx.readOnly {
tx.db.roTxsLimiter.Release(1)
} else {
Expand Down
128 changes: 126 additions & 2 deletions erigon-lib/kv/mdbx/kv_mdbx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ package mdbx

import (
"context"
"sync/atomic"
"testing"
"time"

"github.com/c2h5oh/datasize"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/order"
"github.com/ledgerwatch/log/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/order"
)

func BaseCase(t *testing.T) (kv.RwDB, kv.RwTx, kv.RwCursorDupSort) {
Expand Down Expand Up @@ -773,3 +776,124 @@ func TestAutoConversionSeekBothRange(t *testing.T) {
require.NoError(t, err)
assert.Nil(t, v)
}

func TestBeginRoAfterClose(t *testing.T) {
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
db.Close()
_, err := db.BeginRo(context.Background())
require.ErrorContains(t, err, "closed")
}

func TestBeginRwAfterClose(t *testing.T) {
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
db.Close()
_, err := db.BeginRw(context.Background())
require.ErrorContains(t, err, "closed")
}

func TestBeginRoWithDoneContext(t *testing.T) {
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := db.BeginRo(ctx)
require.ErrorIs(t, err, context.Canceled)
}

func TestBeginRwWithDoneContext(t *testing.T) {
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := db.BeginRw(ctx)
require.ErrorIs(t, err, context.Canceled)
}

func testCloseWaitsAfterTxBegin(
t *testing.T,
count int,
txBeginFunc func(kv.RwDB) (kv.StatelessReadTx, error),
txEndFunc func(kv.StatelessReadTx) error,
) {
t.Helper()
db := NewMDBX(log.New()).InMem(t.TempDir()).MustOpen()
var txs []kv.StatelessReadTx
for i := 0; i < count; i++ {
tx, err := txBeginFunc(db)
require.Nil(t, err)
txs = append(txs, tx)
}

isClosed := &atomic.Bool{}
closeDone := make(chan struct{})

go func() {
db.Close()
isClosed.Store(true)
close(closeDone)
}()

for _, tx := range txs {
// arbitrary delay to give db.Close() a chance to exit prematurely
time.Sleep(time.Millisecond * 20)
assert.False(t, isClosed.Load())

err := txEndFunc(tx)
require.Nil(t, err)
}

<-closeDone
assert.True(t, isClosed.Load())
}

func TestCloseWaitsAfterTxBegin(t *testing.T) {
ctx := context.Background()
t.Run("BeginRoAndCommit", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
1,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) },
func(tx kv.StatelessReadTx) error { return tx.Commit() },
)
})
t.Run("BeginRoAndCommit3", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
3,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) },
func(tx kv.StatelessReadTx) error { return tx.Commit() },
)
})
t.Run("BeginRoAndRollback", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
1,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) },
func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil },
)
})
t.Run("BeginRoAndRollback3", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
3,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRo(ctx) },
func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil },
)
})
t.Run("BeginRwAndCommit", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
1,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRw(ctx) },
func(tx kv.StatelessReadTx) error { return tx.Commit() },
)
})
t.Run("BeginRwAndRollback", func(t *testing.T) {
testCloseWaitsAfterTxBegin(
t,
1,
func(db kv.RwDB) (kv.StatelessReadTx, error) { return db.BeginRw(ctx) },
func(tx kv.StatelessReadTx) error { tx.Rollback(); return nil },
)
})
}

0 comments on commit 1914b52

Please sign in to comment.