From 9b946de9c71b279d8c5861b3283452799ea2f453 Mon Sep 17 00:00:00 2001 From: Patrik Date: Thu, 19 May 2022 16:13:23 +0200 Subject: [PATCH] feat: expose connection transactions with context and options --- connection.go | 20 +++++++++++++------- connection_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/connection.go b/connection.go index f8b81d31..ecb3e87c 100644 --- a/connection.go +++ b/connection.go @@ -2,6 +2,7 @@ package pop import ( "context" + "database/sql" "errors" "fmt" "sync/atomic" @@ -185,21 +186,26 @@ func (c *Connection) Rollback(fn func(tx *Connection)) error { // NewTransaction starts a new transaction on the connection func (c *Connection) NewTransaction() (*Connection, error) { + return c.NewTransactionContextOptions(c.Context(), nil) +} + +// NewTransactionContext starts a new transaction on the connection using the provided context +func (c *Connection) NewTransactionContext(ctx context.Context) (*Connection, error) { + return c.NewTransactionContextOptions(ctx, nil) +} + +// NewTransactionContextOptions starts a new transaction on the connection using the provided context and transaction options +func (c *Connection) NewTransactionContextOptions(ctx context.Context, options *sql.TxOptions) (*Connection, error) { var cn *Connection if c.TX == nil { - tx, err := c.Store.Transaction() + tx, err := c.Store.TransactionContextOptions(ctx, options) if err != nil { return cn, fmt.Errorf("couldn't start a new transaction: %w", err) } - var store store = tx - // Rewrap the store if it was a context store - if cs, ok := c.Store.(contextStore); ok { - store = contextStore{store: store, ctx: cs.ctx} - } cn = &Connection{ ID: randx.String(30), - Store: store, + Store: contextStore{store: tx, ctx: ctx}, Dialect: c.Dialect, TX: tx, } diff --git a/connection_test.go b/connection_test.go index dfdb9756..a923a281 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1,8 +1,10 @@ +//go:build sqlite // +build sqlite package pop import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -52,3 +54,46 @@ func Test_Connection_Open_BadDriver(t *testing.T) { err = c.Open() r.Error(err) } + +func Test_Connection_Transaction(t *testing.T) { + r := require.New(t) + ctx := context.WithValue(context.Background(), "test", "test") + + c, err := NewConnection(&ConnectionDetails{ + URL: "sqlite://file::memory:?_fk=true", + }) + r.NoError(err) + r.NoError(c.Open()) + c = c.WithContext(ctx) + + t.Run("func=NewTransaction", func(t *testing.T) { + r := require.New(t) + tx, err := c.NewTransaction() + r.NoError(err) + + // has transaction and context + r.NotNil(tx.TX) + r.Nil(c.TX) + r.Equal(ctx, tx.Context()) + + // does not start a new transaction + ntx, err := tx.NewTransaction() + r.Equal(tx, ntx) + + r.NoError(tx.TX.Rollback()) + }) + + t.Run("func=NewTransactionContext", func(t *testing.T) { + r := require.New(t) + nctx := context.WithValue(ctx, "nested", "test") + tx, err := c.NewTransactionContext(nctx) + r.NoError(err) + + // has transaction and context + r.NotNil(tx.TX) + r.Nil(c.TX) + r.Equal(nctx, tx.Context()) + + r.NoError(tx.TX.Rollback()) + }) +}