diff --git a/internal/app/pactproxy/interaction.go b/internal/app/pactproxy/interaction.go index a5c7e7f..c38e1af 100644 --- a/internal/app/pactproxy/interaction.go +++ b/internal/app/pactproxy/interaction.go @@ -8,7 +8,6 @@ import ( "regexp" "strings" "sync" - "sync/atomic" log "github.com/sirupsen/logrus" @@ -42,15 +41,16 @@ func (m *regexPathMatcher) match(val string) bool { } type interaction struct { + mu sync.RWMutex pathMatcher pathMatcher method string Alias string Description string definition map[string]interface{} - constraints sync.Map + constraints map[string]interactionConstraint Modifiers *interactionModifiers - lastRequest atomic.Value - requestCount int32 + lastRequest requestDocument + requestCount int } func LoadInteraction(data []byte, alias string) (*interaction, error) { @@ -95,11 +95,12 @@ func LoadInteraction(data []byte, alias string) (*interaction, error) { Alias: alias, definition: definition, Description: description, + constraints: map[string]interactionConstraint{}, } interaction.Modifiers = &interactionModifiers{ interaction: interaction, - modifiers: sync.Map{}, + modifiers: map[string]*interactionModifier{}, } requestBody, ok := request["body"] @@ -296,7 +297,9 @@ func (i *interaction) Match(path, method string) bool { } func (i *interaction) AddConstraint(constraint interactionConstraint) { - i.constraints.Store(constraint.Key(), constraint) + i.mu.Lock() + defer i.mu.Unlock() + i.constraints[constraint.Key()] = constraint } func (i *interaction) loadValuesFromSource(constraint interactionConstraint, interactions *Interactions) ([]interface{}, error) { @@ -306,8 +309,10 @@ func (i *interaction) loadValuesFromSource(constraint interactionConstraint, int return nil, errors.Errorf("cannot find source interaction '%s' for constraint", constraint.Source) } - sourceRequest, ok := sourceInteraction.lastRequest.Load().(requestDocument) - if !ok { + i.mu.RLock() + sourceRequest := sourceInteraction.lastRequest + i.mu.RUnlock() + if len(sourceRequest) == 0 { return nil, errors.Errorf("source interaction '%s' as no requests", constraint.Source) } @@ -322,8 +327,9 @@ func (i *interaction) EvaluateConstrains(request requestDocument, interactions * result := true violations := make([]string, 0) - i.constraints.Range(func(_, v interface{}) bool { - constraint := v.(interactionConstraint) + i.mu.RLock() + defer i.mu.RUnlock() + for _, constraint := range i.constraints { values := constraint.Values if constraint.Source != "" { var err error @@ -331,7 +337,7 @@ func (i *interaction) EvaluateConstrains(request requestDocument, interactions * if err != nil { violations = append(violations, err.Error()) result = false - return true + continue } } @@ -342,7 +348,7 @@ func (i *interaction) EvaluateConstrains(request requestDocument, interactions * } if reflect.TypeOf(val) == reflect.TypeOf([]interface{}{}) { log.Infof("skipping matching on interface{} type for path '%s'", constraint.Path) - return true + continue } if err == nil { actual = fmt.Sprintf("%v", val) @@ -353,18 +359,24 @@ func (i *interaction) EvaluateConstrains(request requestDocument, interactions * violations = append(violations, fmt.Sprintf("value '%s' at path '%s' does not match constraint '%s'", actual, constraint.Path, expected)) result = false } - - return true - }) + } return result, violations } func (i *interaction) StoreRequest(request requestDocument) { - i.lastRequest.Store(request) - atomic.AddInt32(&i.requestCount, 1) + i.mu.Lock() + defer i.mu.Unlock() + i.lastRequest = request + i.requestCount++ } func (i *interaction) HasRequests(count int) bool { - return atomic.LoadInt32(&i.requestCount) >= int32(count) + return i.requestCount >= i.getRequestCount() +} + +func (i *interaction) getRequestCount() int { + i.mu.RLock() + defer i.mu.RUnlock() + return i.requestCount } diff --git a/internal/app/pactproxy/interaction_test.go b/internal/app/pactproxy/interaction_test.go index 5b75b86..09bb669 100644 --- a/internal/app/pactproxy/interaction_test.go +++ b/internal/app/pactproxy/interaction_test.go @@ -2,9 +2,10 @@ package pactproxy import ( "encoding/json" + "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" ) func TestLoadInteractionPlainTextConstraints(t *testing.T) { @@ -104,18 +105,17 @@ func TestLoadInteractionPlainTextConstraints(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := LoadInteraction(tt.interaction, "alias") + interaction, err := LoadInteraction(tt.interaction, "alias") require.Equalf(t, tt.wantErr, err != nil, "error %v", err) - var gotConstraint interactionConstraint - got.constraints.Range(func(key, value interface{}) bool { - var present bool - gotConstraint, present = value.(interactionConstraint) - return present - }) + var foundConstraint interactionConstraint + for _, constraint := range interaction.constraints { + foundConstraint = constraint + break + } - assert.EqualValues(t, tt.wantConstraint, gotConstraint) + assert.EqualValues(t, tt.wantConstraint, foundConstraint) }) } } @@ -239,18 +239,17 @@ func TestV3MatchingRulesLeadToCorrectConstraints(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := LoadInteraction(tt.interaction, "alias") + interaction, err := LoadInteraction(tt.interaction, "alias") require.Equalf(t, tt.wantErr, err != nil, "error %v", err) - var gotConstraint interactionConstraint - got.constraints.Range(func(key, value interface{}) bool { - var present bool - gotConstraint, present = value.(interactionConstraint) - return present - }) + var foundConstraint interactionConstraint + for _, constraint := range interaction.constraints { + foundConstraint = constraint + break + } - assert.EqualValues(t, tt.wantConstraint, gotConstraint) + assert.EqualValues(t, tt.wantConstraint, foundConstraint) }) } } diff --git a/internal/app/pactproxy/modifier.go b/internal/app/pactproxy/modifier.go index 72d00b0..e3ad20d 100644 --- a/internal/app/pactproxy/modifier.go +++ b/internal/app/pactproxy/modifier.go @@ -4,8 +4,6 @@ import ( "fmt" "strconv" "strings" - "sync" - "sync/atomic" "github.com/tidwall/sjson" ) @@ -19,7 +17,7 @@ type interactionModifier struct { type interactionModifiers struct { interaction *interaction - modifiers sync.Map + modifiers map[string]*interactionModifier } func (im *interactionModifier) Key() string { @@ -27,22 +25,25 @@ func (im *interactionModifier) Key() string { } func (ims *interactionModifiers) AddModifier(modifier *interactionModifier) { - ims.modifiers.Store(modifier.Key(), modifier) + ims.interaction.mu.Lock() + defer ims.interaction.mu.Unlock() + ims.modifiers[modifier.Key()] = modifier } func (ims *interactionModifiers) Modifiers() []*interactionModifier { var result []*interactionModifier - ims.modifiers.Range(func(_, v interface{}) bool { - result = append(result, v.(*interactionModifier)) - return true - }) + ims.interaction.mu.RLock() + defer ims.interaction.mu.RUnlock() + for _, modifier := range ims.modifiers { + result = append(result, modifier) + } return result } func (ims *interactionModifiers) modifyBody(b []byte) ([]byte, error) { for _, m := range ims.Modifiers() { if strings.HasPrefix(m.Path, "$.body.") { - if m.Attempt == nil || *m.Attempt == int(atomic.LoadInt32(&ims.interaction.requestCount)) { + if m.Attempt == nil || *m.Attempt == ims.interaction.getRequestCount() { var err error b, err = sjson.SetBytes(b, m.Path[7:], m.Value) if err != nil { @@ -57,7 +58,7 @@ func (ims *interactionModifiers) modifyBody(b []byte) ([]byte, error) { func (ims *interactionModifiers) modifyStatusCode() (bool, int) { for _, m := range ims.Modifiers() { if m.Path == "$.status" { - if m.Attempt == nil || *m.Attempt == int(atomic.LoadInt32(&ims.interaction.requestCount)) { + if m.Attempt == nil || *m.Attempt == ims.interaction.getRequestCount() { code, err := strconv.Atoi(fmt.Sprintf("%v", m.Value)) if err == nil { return true, code diff --git a/internal/app/pactproxy/proxy_test.go b/internal/app/pactproxy/proxy_test.go index 904c8e7..df7b5f8 100644 --- a/internal/app/pactproxy/proxy_test.go +++ b/internal/app/pactproxy/proxy_test.go @@ -46,10 +46,7 @@ func TestInteractionsWaitHandler(t *testing.T) { name: "timing out existing interaction", interactions: func() *Interactions { interactions := Interactions{} - interactions.Store(&interaction{ - Alias: "existing", - Description: "Existing", - }) + interactions.Store(newInteraction("existing")) return &interactions }(), req: func() *http.Request { @@ -72,3 +69,16 @@ func TestInteractionsWaitHandler(t *testing.T) { }) } } + +func newInteraction(alias string) *interaction { + i := &interaction{ + Alias: alias, + Description: alias, + constraints: map[string]interactionConstraint{}, + } + i.Modifiers = &interactionModifiers{ + interaction: i, + modifiers: map[string]*interactionModifier{}, + } + return i +}