Skip to content

Commit

Permalink
fix the Postgres database transaction cannot continue to execute afte…
Browse files Browse the repository at this point in the history
…r failure (#50)
  • Loading branch information
mayswind committed Feb 16, 2025
1 parent 8f55bd0 commit a9a37b0
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 8 deletions.
24 changes: 23 additions & 1 deletion pkg/datastore/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"xorm.io/xorm"

"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/settings"
)

// Database represents a database instance
type Database struct {
engineGroup *xorm.EngineGroup
databaseType string
engineGroup *xorm.EngineGroup
}

// NewSession starts a new session with the specified context
Expand Down Expand Up @@ -41,3 +43,23 @@ func (db *Database) DoTransaction(c core.Context, fn func(sess *xorm.Session) er

return nil
}

// SetSavePoint sets a save point in the current transaction for Postgres
func (db *Database) SetSavePoint(sess *xorm.Session, savePointName string) error {
if db.databaseType == settings.PostgresDbType {
_, err := sess.Exec("SAVEPOINT " + savePointName)
return err
}

return nil
}

// RollbackToSavePoint rolls back to the specified save point in the current transaction for Postgres
func (db *Database) RollbackToSavePoint(sess *xorm.Session, savePointName string) error {
if db.databaseType == settings.PostgresDbType {
_, err := sess.Exec("ROLLBACK TO SAVEPOINT " + savePointName)
return err
}

return nil
}
3 changes: 2 additions & 1 deletion pkg/datastore/datastore_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ func initializeDatabase(dbConfig *settings.DatabaseConfig) (*Database, error) {
engineGroup.SetConnMaxLifetime(time.Duration(dbConfig.ConnectionMaxLifeTime) * time.Second)

return &Database{
engineGroup: engineGroup,
databaseType: dbConfig.DatabaseType,
engineGroup: engineGroup,
}, nil
}

Expand Down
21 changes: 20 additions & 1 deletion pkg/services/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/datastore"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/models"
"github.com/mayswind/ezbookkeeping/pkg/utils"
"github.com/mayswind/ezbookkeeping/pkg/uuid"
Expand Down Expand Up @@ -270,7 +271,9 @@ func (s *AccountService) CreateAccounts(c core.Context, mainAccount *models.Acco
}
}

return s.UserDataDB(mainAccount.Uid).DoTransaction(c, func(sess *xorm.Session) error {
userDataDb := s.UserDataDB(mainAccount.Uid)

return userDataDb.DoTransaction(c, func(sess *xorm.Session) error {
for i := 0; i < len(allAccounts); i++ {
account := allAccounts[i]
_, err := sess.Insert(account)
Expand All @@ -282,9 +285,25 @@ func (s *AccountService) CreateAccounts(c core.Context, mainAccount *models.Acco

for i := 0; i < len(allInitTransactions); i++ {
transaction := allInitTransactions[i]

insertTransactionSavePointName := "insert_transaction"
err := userDataDb.SetSavePoint(sess, insertTransactionSavePointName)

if err != nil {
log.Errorf(c, "[accounts.CreateAccounts] failed to set save point \"%s\", because %s", insertTransactionSavePointName, err.Error())
return err
}

createdRows, err := sess.Insert(transaction)

if err != nil || createdRows < 1 { // maybe another transaction has same time
err = userDataDb.RollbackToSavePoint(sess, insertTransactionSavePointName)

if err != nil {
log.Errorf(c, "[accounts.CreateAccounts] failed to rollback to save point \"%s\", because %s", insertTransactionSavePointName, err.Error())
return err
}

sameSecondLatestTransaction := &models.Transaction{}
minTransactionTime := utils.GetMinTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime))
maxTransactionTime := utils.GetMaxTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime))
Expand Down
29 changes: 24 additions & 5 deletions pkg/services/transactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,10 @@ func (s *TransactionService) CreateTransaction(c core.Context, transaction *mode
UpdatedUnixTime: now,
}

return s.UserDataDB(transaction.Uid).DoTransaction(c, func(sess *xorm.Session) error {
return s.doCreateTransaction(c, sess, transaction, transactionTagIndexes, tagIds, pictureIds, pictureUpdateModel)
userDataDb := s.UserDataDB(transaction.Uid)

return userDataDb.DoTransaction(c, func(sess *xorm.Session) error {
return s.doCreateTransaction(c, userDataDb, sess, transaction, transactionTagIndexes, tagIds, pictureIds, pictureUpdateModel)
})
}

Expand Down Expand Up @@ -355,12 +357,14 @@ func (s *TransactionService) BatchCreateTransactions(c core.Context, uid int64,
allTransactionTagIds[transaction.TransactionId] = uniqueTagIds
}

return s.UserDataDB(uid).DoTransaction(c, func(sess *xorm.Session) error {
userDataDb := s.UserDataDB(uid)

return userDataDb.DoTransaction(c, func(sess *xorm.Session) error {
for i := 0; i < len(transactions); i++ {
transaction := transactions[i]
transactionTagIndexes := allTransactionTagIndexes[transaction.TransactionId]
transactionTagIds := allTransactionTagIds[transaction.TransactionId]
err := s.doCreateTransaction(c, sess, transaction, transactionTagIndexes, transactionTagIds, nil, nil)
err := s.doCreateTransaction(c, userDataDb, sess, transaction, transactionTagIndexes, transactionTagIds, nil, nil)

if err != nil {
transactionUnixTime := utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime)
Expand Down Expand Up @@ -1562,7 +1566,7 @@ func (s *TransactionService) GetTransactionIds(transactions []*models.Transactio
return transactionIds
}

func (s *TransactionService) doCreateTransaction(c core.Context, sess *xorm.Session, transaction *models.Transaction, transactionTagIndexes []*models.TransactionTagIndex, tagIds []int64, pictureIds []int64, pictureUpdateModel *models.TransactionPictureInfo) error {
func (s *TransactionService) doCreateTransaction(c core.Context, database *datastore.Database, sess *xorm.Session, transaction *models.Transaction, transactionTagIndexes []*models.TransactionTagIndex, tagIds []int64, pictureIds []int64, pictureUpdateModel *models.TransactionPictureInfo) error {
// Get and verify source and destination account
sourceAccount, destinationAccount, err := s.getAccountModels(sess, transaction)

Expand Down Expand Up @@ -1646,6 +1650,14 @@ func (s *TransactionService) doCreateTransaction(c core.Context, sess *xorm.Sess
relatedTransaction = s.GetRelatedTransferTransaction(transaction)
}

insertTransactionSavePointName := "insert_transaction"
err = database.SetSavePoint(sess, insertTransactionSavePointName)

if err != nil {
log.Errorf(c, "[transactions.doCreateTransaction] failed to set save point \"%s\", because %s", insertTransactionSavePointName, err.Error())
return err
}

createdRows, err := sess.Insert(transaction)

if err != nil || createdRows < 1 { // maybe another transaction has same time
Expand All @@ -1655,6 +1667,13 @@ func (s *TransactionService) doCreateTransaction(c core.Context, sess *xorm.Sess
log.Warnf(c, "[transactions.doCreateTransaction] cannot create trasaction, regenerate transaction time value")
}

err = database.RollbackToSavePoint(sess, insertTransactionSavePointName)

if err != nil {
log.Errorf(c, "[transactions.doCreateTransaction] failed to rollback to save point \"%s\", because %s", insertTransactionSavePointName, err.Error())
return err
}

sameSecondLatestTransaction := &models.Transaction{}
minTransactionTime := utils.GetMinTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime))
maxTransactionTime := utils.GetMaxTransactionTimeFromUnixTime(utils.GetUnixTimeFromTransactionTime(transaction.TransactionTime))
Expand Down

0 comments on commit a9a37b0

Please sign in to comment.