Skip to content

Commit

Permalink
fix: use generic update func to eliminate tx propagation failure (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
leg100 authored Jan 16, 2025
1 parent 1d9d18e commit 900fbe7
Show file tree
Hide file tree
Showing 17 changed files with 345 additions and 312 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
run: make test
- name: Archive browser screenshots
if: always()
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: e2e-screenshots
path: internal/integration/screenshots/**/*.png
Expand Down
58 changes: 30 additions & 28 deletions internal/notifications/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,34 +69,36 @@ func (db *pgdb) create(ctx context.Context, nc *Config) error {
return sql.Error(err)
}

func (db *pgdb) update(ctx context.Context, id resource.ID, updateFunc func(*Config) error) (*Config, error) {
var nc *Config
err := db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
result, err := q.FindNotificationConfigurationForUpdate(ctx, id)
if err != nil {
return sql.Error(err)
}
nc = pgresult(result).toNotificationConfiguration()
if err := updateFunc(nc); err != nil {
return sql.Error(err)
}
params := sqlc.UpdateNotificationConfigurationByIDParams{
UpdatedAt: sql.Timestamptz(internal.CurrentTimestamp(nil)),
Enabled: sql.Bool(nc.Enabled),
Name: sql.String(nc.Name),
URL: sql.NullString(),
NotificationConfigurationID: nc.ID,
}
for _, t := range nc.Triggers {
params.Triggers = append(params.Triggers, sql.String(string(t)))
}
if nc.URL != nil {
params.URL = sql.String(*nc.URL)
}
_, err = q.UpdateNotificationConfigurationByID(ctx, params)
return err
})
return nc, err
func (db *pgdb) update(ctx context.Context, id resource.ID, updateFunc func(context.Context, *Config) error) (*Config, error) {
return sql.Updater(
ctx,
db.DB,
func(ctx context.Context, q *sqlc.Queries) (*Config, error) {
result, err := q.FindNotificationConfigurationForUpdate(ctx, id)
if err != nil {
return nil, sql.Error(err)
}
return pgresult(result).toNotificationConfiguration(), nil
},
updateFunc,
func(ctx context.Context, q *sqlc.Queries, nc *Config) error {
params := sqlc.UpdateNotificationConfigurationByIDParams{
UpdatedAt: sql.Timestamptz(internal.CurrentTimestamp(nil)),
Enabled: sql.Bool(nc.Enabled),
Name: sql.String(nc.Name),
URL: sql.NullString(),
NotificationConfigurationID: nc.ID,
}
for _, t := range nc.Triggers {
params.Triggers = append(params.Triggers, sql.String(string(t)))
}
if nc.URL != nil {
params.URL = sql.String(*nc.URL)
}
_, err := q.UpdateNotificationConfigurationByID(ctx, params)
return err
},
)
}

func (db *pgdb) list(ctx context.Context, workspaceID resource.ID) ([]*Config, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/notifications/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *Service) Create(ctx context.Context, workspaceID resource.ID, opts Crea

func (s *Service) Update(ctx context.Context, id resource.ID, opts UpdateConfigOptions) (*Config, error) {
var subject authz.Subject
updated, err := s.db.update(ctx, id, func(nc *Config) (err error) {
updated, err := s.db.update(ctx, id, func(ctx context.Context, nc *Config) (err error) {
subject, err = s.Authorize(ctx, authz.UpdateNotificationConfigurationAction, &authz.AccessRequest{ID: &nc.WorkspaceID})
if err != nil {
return err
Expand Down
54 changes: 26 additions & 28 deletions internal/organization/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,35 +85,33 @@ func (db *pgdb) create(ctx context.Context, org *Organization) error {
return nil
}

func (db *pgdb) update(ctx context.Context, name string, fn func(*Organization) error) (*Organization, error) {
var org *Organization
err := db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
result, err := q.FindOrganizationByNameForUpdate(ctx, sql.String(name))
if err != nil {
func (db *pgdb) update(ctx context.Context, name string, fn func(context.Context, *Organization) error) (*Organization, error) {
return sql.Updater(
ctx,
db.DB,
func(ctx context.Context, q *sqlc.Queries) (*Organization, error) {
result, err := q.FindOrganizationByNameForUpdate(ctx, sql.String(name))
if err != nil {
return nil, err
}
return row(result).toOrganization(), nil
},
fn,
func(ctx context.Context, q *sqlc.Queries, org *Organization) error {
_, err := q.UpdateOrganizationByName(ctx, sqlc.UpdateOrganizationByNameParams{
Name: sql.String(name),
NewName: sql.String(org.Name),
Email: sql.StringPtr(org.Email),
CollaboratorAuthPolicy: sql.StringPtr(org.CollaboratorAuthPolicy),
CostEstimationEnabled: sql.Bool(org.CostEstimationEnabled),
SessionRemember: sql.Int4Ptr(org.SessionRemember),
SessionTimeout: sql.Int4Ptr(org.SessionTimeout),
UpdatedAt: sql.Timestamptz(org.UpdatedAt),
AllowForceDeleteWorkspaces: sql.Bool(org.AllowForceDeleteWorkspaces),
})
return err
}
org = row(result).toOrganization()

if err := fn(org); err != nil {
return err
}
_, err = q.UpdateOrganizationByName(ctx, sqlc.UpdateOrganizationByNameParams{
Name: sql.String(name),
NewName: sql.String(org.Name),
Email: sql.StringPtr(org.Email),
CollaboratorAuthPolicy: sql.StringPtr(org.CollaboratorAuthPolicy),
CostEstimationEnabled: sql.Bool(org.CostEstimationEnabled),
SessionRemember: sql.Int4Ptr(org.SessionRemember),
SessionTimeout: sql.Int4Ptr(org.SessionTimeout),
UpdatedAt: sql.Timestamptz(org.UpdatedAt),
AllowForceDeleteWorkspaces: sql.Bool(org.AllowForceDeleteWorkspaces),
})
if err != nil {
return err
}
return nil
})
return org, err
},
)
}

func (db *pgdb) list(ctx context.Context, opts dbListOptions) (*resource.Page[*Organization], error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/organization/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (s *Service) Update(ctx context.Context, name string, opts UpdateOptions) (
if err != nil {
return nil, err
}
org, err := s.db.update(ctx, name, func(org *Organization) error {
org, err := s.db.update(ctx, name, func(ctx context.Context, org *Organization) error {
return org.Update(opts)
})
if err != nil {
Expand Down
127 changes: 66 additions & 61 deletions internal/run/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"sort"
"strconv"
"time"

"github.com/jackc/pgx/v5/pgtype"
"github.com/leg100/otf/internal"
Expand Down Expand Up @@ -211,81 +212,85 @@ func (db *pgdb) CreateRun(ctx context.Context, run *Run) error {
}

// UpdateStatus updates the run status as well as its plan and/or apply.
func (db *pgdb) UpdateStatus(ctx context.Context, runID resource.ID, fn func(*Run) error) (*Run, error) {
var run *Run
err := db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
// select ...for update
result, err := q.FindRunByIDForUpdate(ctx, runID)
if err != nil {
return sql.Error(err)
}
run = pgresult(result).toRun()

// Make copies of run attributes before update
runStatus := run.Status
planStatus := run.Plan.Status
applyStatus := run.Apply.Status
cancelSignaledAt := run.CancelSignaledAt
func (db *pgdb) UpdateStatus(ctx context.Context, runID resource.ID, fn func(context.Context, *Run) error) (*Run, error) {
var runStatus Status
var planStatus PhaseStatus
var applyStatus PhaseStatus
var cancelSignaledAt *time.Time

if err := fn(run); err != nil {
return err
}

if run.Status != runStatus {
_, err := q.UpdateRunStatus(ctx, sqlc.UpdateRunStatusParams{
Status: sql.String(string(run.Status)),
ID: run.ID,
})
return sql.Updater(
ctx,
db.DB,
func(ctx context.Context, q *sqlc.Queries) (*Run, error) {
result, err := q.FindRunByIDForUpdate(ctx, runID)
if err != nil {
return err
return nil, err
}
run := pgresult(result).toRun()
// Make copies of run attributes before update
runStatus = run.Status
planStatus = run.Plan.Status
applyStatus = run.Apply.Status
cancelSignaledAt = run.CancelSignaledAt
return run, nil
},
fn,
func(ctx context.Context, q *sqlc.Queries, run *Run) error {
if run.Status != runStatus {
_, err := q.UpdateRunStatus(ctx, sqlc.UpdateRunStatusParams{
Status: sql.String(string(run.Status)),
ID: run.ID,
})
if err != nil {
return err
}

if err := db.insertRunStatusTimestamp(ctx, run); err != nil {
return err
if err := db.insertRunStatusTimestamp(ctx, run); err != nil {
return err
}
}
}

if run.Plan.Status != planStatus {
_, err := q.UpdatePlanStatusByID(ctx, sqlc.UpdatePlanStatusByIDParams{
Status: sql.String(string(run.Plan.Status)),
RunID: run.ID,
})
if err != nil {
return err
}
if run.Plan.Status != planStatus {
_, err := q.UpdatePlanStatusByID(ctx, sqlc.UpdatePlanStatusByIDParams{
Status: sql.String(string(run.Plan.Status)),
RunID: run.ID,
})
if err != nil {
return err
}

if err := db.insertPhaseStatusTimestamp(ctx, run.Plan); err != nil {
return err
if err := db.insertPhaseStatusTimestamp(ctx, run.Plan); err != nil {
return err
}
}
}

if run.Apply.Status != applyStatus {
_, err := q.UpdateApplyStatusByID(ctx, sqlc.UpdateApplyStatusByIDParams{
Status: sql.String(string(run.Apply.Status)),
RunID: run.ID,
})
if err != nil {
return err
}
if run.Apply.Status != applyStatus {
_, err := q.UpdateApplyStatusByID(ctx, sqlc.UpdateApplyStatusByIDParams{
Status: sql.String(string(run.Apply.Status)),
RunID: run.ID,
})
if err != nil {
return err
}

if err := db.insertPhaseStatusTimestamp(ctx, run.Apply); err != nil {
return err
if err := db.insertPhaseStatusTimestamp(ctx, run.Apply); err != nil {
return err
}
}
}

if run.CancelSignaledAt != cancelSignaledAt && run.CancelSignaledAt != nil {
_, err := q.UpdateCancelSignaledAt(ctx, sqlc.UpdateCancelSignaledAtParams{
CancelSignaledAt: sql.Timestamptz(*run.CancelSignaledAt),
ID: run.ID,
})
if err != nil {
return err
if run.CancelSignaledAt != cancelSignaledAt && run.CancelSignaledAt != nil {
_, err := q.UpdateCancelSignaledAt(ctx, sqlc.UpdateCancelSignaledAtParams{
CancelSignaledAt: sql.Timestamptz(*run.CancelSignaledAt),
ID: run.ID,
})
if err != nil {
return err
}
}
}

return nil
})
return run, err
return nil
},
)
}

func (db *pgdb) CreatePlanReport(ctx context.Context, runID resource.ID, resource, output Report) error {
Expand Down
14 changes: 7 additions & 7 deletions internal/run/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (s *Service) EnqueuePlan(ctx context.Context, runID resource.ID) (run *Run,
return nil, err
}
err = s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
run, err = s.db.UpdateStatus(ctx, runID, func(run *Run) error {
run, err = s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) error {
return run.EnqueuePlan()
})
if err != nil {
Expand Down Expand Up @@ -284,7 +284,7 @@ func (s *Service) Delete(ctx context.Context, runID resource.ID) error {

// StartPhase starts a run phase.
func (s *Service) StartPhase(ctx context.Context, runID resource.ID, phase internal.PhaseType, _ PhaseStartOptions) (*Run, error) {
run, err := s.db.UpdateStatus(ctx, runID, func(run *Run) error {
run, err := s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) error {
return run.Start()
})
if err != nil {
Expand Down Expand Up @@ -317,7 +317,7 @@ func (s *Service) FinishPhase(ctx context.Context, runID resource.ID, phase inte
var run *Run
err := s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) (err error) {
var autoapply bool
run, err = s.db.UpdateStatus(ctx, runID, func(run *Run) (err error) {
run, err = s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) (err error) {
autoapply, err = run.Finish(phase, opts)
return err
})
Expand Down Expand Up @@ -391,7 +391,7 @@ func (s *Service) Apply(ctx context.Context, runID resource.ID) error {
return err
}
return s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
run, err := s.db.UpdateStatus(ctx, runID, func(run *Run) error {
run, err := s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) error {
return run.EnqueueApply()
})
if err != nil {
Expand Down Expand Up @@ -422,7 +422,7 @@ func (s *Service) Discard(ctx context.Context, runID resource.ID) error {
return err
}

_, err = s.db.UpdateStatus(ctx, runID, func(run *Run) error {
_, err = s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) error {
return run.Discard()
})
if err != nil {
Expand All @@ -443,7 +443,7 @@ func (s *Service) Cancel(ctx context.Context, runID resource.ID) error {
return s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
_, isUser := subject.(*user.User)

run, err := s.db.UpdateStatus(ctx, runID, func(run *Run) (err error) {
run, err := s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) (err error) {
return run.Cancel(isUser, false)
})
if err != nil {
Expand Down Expand Up @@ -477,7 +477,7 @@ func (s *Service) ForceCancel(ctx context.Context, runID resource.ID) error {
return err
}
return s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
run, err := s.db.UpdateStatus(ctx, runID, func(run *Run) (err error) {
run, err := s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) (err error) {
return run.Cancel(true, true)
})
if err != nil {
Expand Down
Loading

0 comments on commit 900fbe7

Please sign in to comment.