Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use generic update func to eliminate tx propagation failure #720

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
run: make test
- name: Archive browser screenshots
if: always()
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: e2e-screenshots
path: internal/integration/screenshots/**/*.png
Expand Down
58 changes: 30 additions & 28 deletions internal/notifications/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,34 +69,36 @@ func (db *pgdb) create(ctx context.Context, nc *Config) error {
return sql.Error(err)
}

func (db *pgdb) update(ctx context.Context, id resource.ID, updateFunc func(*Config) error) (*Config, error) {
var nc *Config
err := db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
result, err := q.FindNotificationConfigurationForUpdate(ctx, id)
if err != nil {
return sql.Error(err)
}
nc = pgresult(result).toNotificationConfiguration()
if err := updateFunc(nc); err != nil {
return sql.Error(err)
}
params := sqlc.UpdateNotificationConfigurationByIDParams{
UpdatedAt: sql.Timestamptz(internal.CurrentTimestamp(nil)),
Enabled: sql.Bool(nc.Enabled),
Name: sql.String(nc.Name),
URL: sql.NullString(),
NotificationConfigurationID: nc.ID,
}
for _, t := range nc.Triggers {
params.Triggers = append(params.Triggers, sql.String(string(t)))
}
if nc.URL != nil {
params.URL = sql.String(*nc.URL)
}
_, err = q.UpdateNotificationConfigurationByID(ctx, params)
return err
})
return nc, err
func (db *pgdb) update(ctx context.Context, id resource.ID, updateFunc func(context.Context, *Config) error) (*Config, error) {
return sql.Updater(
ctx,
db.DB,
func(ctx context.Context, q *sqlc.Queries) (*Config, error) {
result, err := q.FindNotificationConfigurationForUpdate(ctx, id)
if err != nil {
return nil, sql.Error(err)
}
return pgresult(result).toNotificationConfiguration(), nil
},
updateFunc,
func(ctx context.Context, q *sqlc.Queries, nc *Config) error {
params := sqlc.UpdateNotificationConfigurationByIDParams{
UpdatedAt: sql.Timestamptz(internal.CurrentTimestamp(nil)),
Enabled: sql.Bool(nc.Enabled),
Name: sql.String(nc.Name),
URL: sql.NullString(),
NotificationConfigurationID: nc.ID,
}
for _, t := range nc.Triggers {
params.Triggers = append(params.Triggers, sql.String(string(t)))
}
if nc.URL != nil {
params.URL = sql.String(*nc.URL)
}
_, err := q.UpdateNotificationConfigurationByID(ctx, params)
return err
},
)
}

func (db *pgdb) list(ctx context.Context, workspaceID resource.ID) ([]*Config, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/notifications/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (s *Service) Create(ctx context.Context, workspaceID resource.ID, opts Crea

func (s *Service) Update(ctx context.Context, id resource.ID, opts UpdateConfigOptions) (*Config, error) {
var subject authz.Subject
updated, err := s.db.update(ctx, id, func(nc *Config) (err error) {
updated, err := s.db.update(ctx, id, func(ctx context.Context, nc *Config) (err error) {
subject, err = s.Authorize(ctx, authz.UpdateNotificationConfigurationAction, &authz.AccessRequest{ID: &nc.WorkspaceID})
if err != nil {
return err
Expand Down
54 changes: 26 additions & 28 deletions internal/organization/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,35 +85,33 @@ func (db *pgdb) create(ctx context.Context, org *Organization) error {
return nil
}

func (db *pgdb) update(ctx context.Context, name string, fn func(*Organization) error) (*Organization, error) {
var org *Organization
err := db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
result, err := q.FindOrganizationByNameForUpdate(ctx, sql.String(name))
if err != nil {
func (db *pgdb) update(ctx context.Context, name string, fn func(context.Context, *Organization) error) (*Organization, error) {
return sql.Updater(
ctx,
db.DB,
func(ctx context.Context, q *sqlc.Queries) (*Organization, error) {
result, err := q.FindOrganizationByNameForUpdate(ctx, sql.String(name))
if err != nil {
return nil, err
}
return row(result).toOrganization(), nil
},
fn,
func(ctx context.Context, q *sqlc.Queries, org *Organization) error {
_, err := q.UpdateOrganizationByName(ctx, sqlc.UpdateOrganizationByNameParams{
Name: sql.String(name),
NewName: sql.String(org.Name),
Email: sql.StringPtr(org.Email),
CollaboratorAuthPolicy: sql.StringPtr(org.CollaboratorAuthPolicy),
CostEstimationEnabled: sql.Bool(org.CostEstimationEnabled),
SessionRemember: sql.Int4Ptr(org.SessionRemember),
SessionTimeout: sql.Int4Ptr(org.SessionTimeout),
UpdatedAt: sql.Timestamptz(org.UpdatedAt),
AllowForceDeleteWorkspaces: sql.Bool(org.AllowForceDeleteWorkspaces),
})
return err
}
org = row(result).toOrganization()

if err := fn(org); err != nil {
return err
}
_, err = q.UpdateOrganizationByName(ctx, sqlc.UpdateOrganizationByNameParams{
Name: sql.String(name),
NewName: sql.String(org.Name),
Email: sql.StringPtr(org.Email),
CollaboratorAuthPolicy: sql.StringPtr(org.CollaboratorAuthPolicy),
CostEstimationEnabled: sql.Bool(org.CostEstimationEnabled),
SessionRemember: sql.Int4Ptr(org.SessionRemember),
SessionTimeout: sql.Int4Ptr(org.SessionTimeout),
UpdatedAt: sql.Timestamptz(org.UpdatedAt),
AllowForceDeleteWorkspaces: sql.Bool(org.AllowForceDeleteWorkspaces),
})
if err != nil {
return err
}
return nil
})
return org, err
},
)
}

func (db *pgdb) list(ctx context.Context, opts dbListOptions) (*resource.Page[*Organization], error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/organization/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (s *Service) Update(ctx context.Context, name string, opts UpdateOptions) (
if err != nil {
return nil, err
}
org, err := s.db.update(ctx, name, func(org *Organization) error {
org, err := s.db.update(ctx, name, func(ctx context.Context, org *Organization) error {
return org.Update(opts)
})
if err != nil {
Expand Down
127 changes: 66 additions & 61 deletions internal/run/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"sort"
"strconv"
"time"

"github.com/jackc/pgx/v5/pgtype"
"github.com/leg100/otf/internal"
Expand Down Expand Up @@ -211,81 +212,85 @@ func (db *pgdb) CreateRun(ctx context.Context, run *Run) error {
}

// UpdateStatus updates the run status as well as its plan and/or apply.
func (db *pgdb) UpdateStatus(ctx context.Context, runID resource.ID, fn func(*Run) error) (*Run, error) {
var run *Run
err := db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
// select ...for update
result, err := q.FindRunByIDForUpdate(ctx, runID)
if err != nil {
return sql.Error(err)
}
run = pgresult(result).toRun()

// Make copies of run attributes before update
runStatus := run.Status
planStatus := run.Plan.Status
applyStatus := run.Apply.Status
cancelSignaledAt := run.CancelSignaledAt
func (db *pgdb) UpdateStatus(ctx context.Context, runID resource.ID, fn func(context.Context, *Run) error) (*Run, error) {
var runStatus Status
var planStatus PhaseStatus
var applyStatus PhaseStatus
var cancelSignaledAt *time.Time

if err := fn(run); err != nil {
return err
}

if run.Status != runStatus {
_, err := q.UpdateRunStatus(ctx, sqlc.UpdateRunStatusParams{
Status: sql.String(string(run.Status)),
ID: run.ID,
})
return sql.Updater(
ctx,
db.DB,
func(ctx context.Context, q *sqlc.Queries) (*Run, error) {
result, err := q.FindRunByIDForUpdate(ctx, runID)
if err != nil {
return err
return nil, err
}
run := pgresult(result).toRun()
// Make copies of run attributes before update
runStatus = run.Status
planStatus = run.Plan.Status
applyStatus = run.Apply.Status
cancelSignaledAt = run.CancelSignaledAt
return run, nil
},
fn,
func(ctx context.Context, q *sqlc.Queries, run *Run) error {
if run.Status != runStatus {
_, err := q.UpdateRunStatus(ctx, sqlc.UpdateRunStatusParams{
Status: sql.String(string(run.Status)),
ID: run.ID,
})
if err != nil {
return err
}

if err := db.insertRunStatusTimestamp(ctx, run); err != nil {
return err
if err := db.insertRunStatusTimestamp(ctx, run); err != nil {
return err
}
}
}

if run.Plan.Status != planStatus {
_, err := q.UpdatePlanStatusByID(ctx, sqlc.UpdatePlanStatusByIDParams{
Status: sql.String(string(run.Plan.Status)),
RunID: run.ID,
})
if err != nil {
return err
}
if run.Plan.Status != planStatus {
_, err := q.UpdatePlanStatusByID(ctx, sqlc.UpdatePlanStatusByIDParams{
Status: sql.String(string(run.Plan.Status)),
RunID: run.ID,
})
if err != nil {
return err
}

if err := db.insertPhaseStatusTimestamp(ctx, run.Plan); err != nil {
return err
if err := db.insertPhaseStatusTimestamp(ctx, run.Plan); err != nil {
return err
}
}
}

if run.Apply.Status != applyStatus {
_, err := q.UpdateApplyStatusByID(ctx, sqlc.UpdateApplyStatusByIDParams{
Status: sql.String(string(run.Apply.Status)),
RunID: run.ID,
})
if err != nil {
return err
}
if run.Apply.Status != applyStatus {
_, err := q.UpdateApplyStatusByID(ctx, sqlc.UpdateApplyStatusByIDParams{
Status: sql.String(string(run.Apply.Status)),
RunID: run.ID,
})
if err != nil {
return err
}

if err := db.insertPhaseStatusTimestamp(ctx, run.Apply); err != nil {
return err
if err := db.insertPhaseStatusTimestamp(ctx, run.Apply); err != nil {
return err
}
}
}

if run.CancelSignaledAt != cancelSignaledAt && run.CancelSignaledAt != nil {
_, err := q.UpdateCancelSignaledAt(ctx, sqlc.UpdateCancelSignaledAtParams{
CancelSignaledAt: sql.Timestamptz(*run.CancelSignaledAt),
ID: run.ID,
})
if err != nil {
return err
if run.CancelSignaledAt != cancelSignaledAt && run.CancelSignaledAt != nil {
_, err := q.UpdateCancelSignaledAt(ctx, sqlc.UpdateCancelSignaledAtParams{
CancelSignaledAt: sql.Timestamptz(*run.CancelSignaledAt),
ID: run.ID,
})
if err != nil {
return err
}
}
}

return nil
})
return run, err
return nil
},
)
}

func (db *pgdb) CreatePlanReport(ctx context.Context, runID resource.ID, resource, output Report) error {
Expand Down
14 changes: 7 additions & 7 deletions internal/run/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (s *Service) EnqueuePlan(ctx context.Context, runID resource.ID) (run *Run,
return nil, err
}
err = s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
run, err = s.db.UpdateStatus(ctx, runID, func(run *Run) error {
run, err = s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) error {
return run.EnqueuePlan()
})
if err != nil {
Expand Down Expand Up @@ -284,7 +284,7 @@ func (s *Service) Delete(ctx context.Context, runID resource.ID) error {

// StartPhase starts a run phase.
func (s *Service) StartPhase(ctx context.Context, runID resource.ID, phase internal.PhaseType, _ PhaseStartOptions) (*Run, error) {
run, err := s.db.UpdateStatus(ctx, runID, func(run *Run) error {
run, err := s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) error {
return run.Start()
})
if err != nil {
Expand Down Expand Up @@ -317,7 +317,7 @@ func (s *Service) FinishPhase(ctx context.Context, runID resource.ID, phase inte
var run *Run
err := s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) (err error) {
var autoapply bool
run, err = s.db.UpdateStatus(ctx, runID, func(run *Run) (err error) {
run, err = s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) (err error) {
autoapply, err = run.Finish(phase, opts)
return err
})
Expand Down Expand Up @@ -391,7 +391,7 @@ func (s *Service) Apply(ctx context.Context, runID resource.ID) error {
return err
}
return s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
run, err := s.db.UpdateStatus(ctx, runID, func(run *Run) error {
run, err := s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) error {
return run.EnqueueApply()
})
if err != nil {
Expand Down Expand Up @@ -422,7 +422,7 @@ func (s *Service) Discard(ctx context.Context, runID resource.ID) error {
return err
}

_, err = s.db.UpdateStatus(ctx, runID, func(run *Run) error {
_, err = s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) error {
return run.Discard()
})
if err != nil {
Expand All @@ -443,7 +443,7 @@ func (s *Service) Cancel(ctx context.Context, runID resource.ID) error {
return s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
_, isUser := subject.(*user.User)

run, err := s.db.UpdateStatus(ctx, runID, func(run *Run) (err error) {
run, err := s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) (err error) {
return run.Cancel(isUser, false)
})
if err != nil {
Expand Down Expand Up @@ -477,7 +477,7 @@ func (s *Service) ForceCancel(ctx context.Context, runID resource.ID) error {
return err
}
return s.db.Tx(ctx, func(ctx context.Context, q *sqlc.Queries) error {
run, err := s.db.UpdateStatus(ctx, runID, func(run *Run) (err error) {
run, err := s.db.UpdateStatus(ctx, runID, func(ctx context.Context, run *Run) (err error) {
return run.Cancel(true, true)
})
if err != nil {
Expand Down
Loading
Loading