From 17f10bc06da8dce934d3ed55ccbada711f35064b Mon Sep 17 00:00:00 2001 From: Mike Tonks Date: Wed, 28 Sep 2022 16:19:04 +0100 Subject: [PATCH] Refactor tests to use testify require library Makes code cleaner and easier to read Co-authored-by: Andrea Rosa Co-authored-by: Kevin Intriago --- internal/app/concurrent_proxy_stage_test.go | 63 +++------ internal/app/proxy_stage_test.go | 149 ++++++-------------- 2 files changed, 60 insertions(+), 152 deletions(-) diff --git a/internal/app/concurrent_proxy_stage_test.go b/internal/app/concurrent_proxy_stage_test.go index 153de8f..d9e332b 100644 --- a/internal/app/concurrent_proxy_stage_test.go +++ b/internal/app/concurrent_proxy_stage_test.go @@ -11,6 +11,7 @@ import ( "github.com/form3tech-oss/pact-proxy/pkg/pactproxy" "github.com/pact-foundation/pact-go/dsl" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" ) const ( @@ -20,6 +21,7 @@ const ( type ConcurrentProxyStage struct { t *testing.T + require *require.Assertions proxy *pactproxy.PactProxy pact *dsl.Pact modifiedNameStatusCode int @@ -40,9 +42,10 @@ func NewConcurrentProxyStage(t *testing.T) (*ConcurrentProxyStage, *ConcurrentPr } s := &ConcurrentProxyStage{ - t: t, - pact: pact, - proxy: proxy, + t: t, + require: require.New(t), + pact: pact, + proxy: proxy, } t.Cleanup(func() { @@ -144,42 +147,28 @@ func (s *ConcurrentProxyStage) the_concurrent_requests_are_sent() { return nil }) - if err != nil { - s.t.Error(err) - } + s.require.NoError(err) } func (s *ConcurrentProxyStage) makeUserRequest() { u := fmt.Sprintf("http://localhost:%s/users", proxyURL.Port()) req, err := http.NewRequest("POST", u, strings.NewReader(`{"name":"jim"}`)) - if err != nil { - s.t.Error(err) - return - } + s.require.NoError(err) req.Header.Set("Content-Type", "application/json") res, err := http.DefaultClient.Do(req) - if err != nil { - s.t.Error(err) - return - } + s.require.NoError(err) s.userResponses = append(s.userResponses, res) } func (s *ConcurrentProxyStage) makeAddressRequest() { u := fmt.Sprintf("http://localhost:%s/addresses", proxyURL.Port()) req, err := http.NewRequest("POST", u, strings.NewReader(`{"address":"test"}`)) - if err != nil { - s.t.Error(err) - return - } + s.require.NoError(err) req.Header.Set("Content-Type", "application/json") res, err := http.DefaultClient.Do(req) - if err != nil { - s.t.Error(err) - return - } + s.require.NoError(err) s.addressResponses = append(s.addressResponses, res) } @@ -208,13 +197,10 @@ func sendConcurrentRequests(requests int, d time.Duration, f func()) { func (s *ConcurrentProxyStage) all_the_user_responses_should_have_the_right_status_code() *ConcurrentProxyStage { expectedLen := s.concurrentUserRequestsPerSecond * int(s.concurrentUserRequestsDuration/time.Second) - if len(s.userResponses) != expectedLen { - s.t.Errorf("expected %d user responses, but got %d", expectedLen, len(s.userResponses)) - } + s.require.Len(s.userResponses, expectedLen, "number of user responsesnot as expected") + for _, res := range s.userResponses { - if s.modifiedNameStatusCode != res.StatusCode { - s.t.Errorf("expected user status code of %d, but got %d", s.modifiedNameStatusCode, res.StatusCode) - } + s.require.Equal(res.StatusCode, s.modifiedNameStatusCode, "expected user status code") } return s @@ -222,13 +208,10 @@ func (s *ConcurrentProxyStage) all_the_user_responses_should_have_the_right_stat func (s *ConcurrentProxyStage) all_the_address_responses_should_have_the_right_status_code() *ConcurrentProxyStage { expectedLen := s.concurrentAddressRequestsPerSecond * int(s.concurrentAddressRequestsDuration/time.Second) - if len(s.addressResponses) != expectedLen { - s.t.Errorf("expected %d address responses, but got %d", expectedLen, len(s.addressResponses)) - } + s.require.Len(s.addressResponses, expectedLen, "number of address responses not as expected") + for _, res := range s.addressResponses { - if s.modifiedAddressStatusCode != res.StatusCode { - s.t.Errorf("expected address status code of %d, but got %d", s.modifiedAddressStatusCode, res.StatusCode) - } + s.require.Equal(res.StatusCode, s.modifiedAddressStatusCode, "expected address status code") } return s @@ -236,22 +219,14 @@ func (s *ConcurrentProxyStage) all_the_address_responses_should_have_the_right_s func (s *ConcurrentProxyStage) the_proxy_waits_for_all_user_responses() *ConcurrentProxyStage { want := s.concurrentUserRequestsPerSecond * int(s.concurrentUserRequestsDuration/time.Second) - received := len(s.userResponses) - if received != want { - s.t.Errorf("expected %d user responses, but got %d", want, received) - s.t.Fail() - } + s.require.Len(s.userResponses, want, "number of user responses, is not as expected") return s } func (s *ConcurrentProxyStage) the_proxy_waits_for_all_address_responses() *ConcurrentProxyStage { want := s.concurrentAddressRequestsPerSecond * int(s.concurrentAddressRequestsDuration/time.Second) - received := len(s.addressResponses) - if received != want { - s.t.Errorf("expected %d address responses, but got %d", want, received) - s.t.Fail() - } + s.require.Len(s.addressResponses, want, "number of address responses not as expected") return s } diff --git a/internal/app/proxy_stage_test.go b/internal/app/proxy_stage_test.go index 4d3c1ba..49c7353 100644 --- a/internal/app/proxy_stage_test.go +++ b/internal/app/proxy_stage_test.go @@ -5,10 +5,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/avast/retry-go/v4" - "github.com/form3tech-oss/pact-proxy/pkg/pactproxy" - "github.com/pact-foundation/pact-go/dsl" - "github.com/pkg/errors" "io" "net/http" "strconv" @@ -16,10 +12,17 @@ import ( "sync/atomic" "testing" "time" + + "github.com/avast/retry-go/v4" + "github.com/form3tech-oss/pact-proxy/pkg/pactproxy" + "github.com/pact-foundation/pact-go/dsl" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" ) type ProxyStage struct { t *testing.T + require *require.Assertions pact *dsl.Pact proxy *pactproxy.PactProxy contentTypeConstraint string @@ -47,6 +50,7 @@ func NewProxyStage(t *testing.T) (*ProxyStage, *ProxyStage, *ProxyStage) { s := &ProxyStage{ t: t, + require: require.New(t), proxy: proxy, pact: pact, modifiedBody: make(map[string]interface{}), @@ -326,24 +330,15 @@ func (s *ProxyStage) n_requests_are_sent_using_the_body_and_content_type(n int, func (s *ProxyStage) send_post_request_and_collect_response(body, url, contentType string) error { req, err := http.NewRequest("POST", url, strings.NewReader(body)) - if err != nil { - s.t.Errorf("request creation failed: %v", err) - return err - } + s.require.NoError(err, "request creation failed") req.Header.Set("Content-Type", contentType) res, err := http.DefaultClient.Do(req) - if err != nil { - s.t.Errorf("sending request failed: %v", err) - return err - } + s.require.NoError(err, "sending request failed") s.responses = append(s.responses, res) bodyBytes, err := io.ReadAll(res.Body) - if err != nil { - s.t.Errorf("unable to read response body: %v", err) - return err - } + s.require.NoError(err, "unable to read response body") s.responseBodies = append(s.responseBodies, bodyBytes) return nil } @@ -356,126 +351,82 @@ func (s *ProxyStage) multiple_requests_are_sent(requestsToSend int32) { for i := int32(0); i < requestsToSend; i++ { u := fmt.Sprintf("http://localhost:%s/users", proxyURL.Port()) req, err := http.NewRequest("POST", u, strings.NewReader(`{"name":"test"}`)) - if err != nil { - s.t.Error(err) - s.t.Fail() - } + s.require.NoError(err) req.Header.Set("Content-Type", "application/json") atomic.AddInt32(&s.requestsSent, 1) - if _, err = http.DefaultClient.Do(req); err != nil { - s.t.Error(err) - s.t.Fail() - } + _, err = http.DefaultClient.Do(req) + s.require.NoError(err) } }() - if err := s.proxy.WaitForInteraction(s.pactName, int(requestsToSend)); err != nil { - s.t.Error(err) - s.t.Fail() - } - + err = s.proxy.WaitForInteraction(s.pactName, int(requestsToSend)) + s.require.NoError(err) return nil }) } func (s *ProxyStage) pact_verification_is_successful() *ProxyStage { - if s.pactResult != nil { - s.t.Error(s.pactResult) - s.t.Fail() - } + s.require.Nil(s.pactResult) return s } func (s *ProxyStage) pact_verification_is_not_successful() *ProxyStage { - if s.pactResult == nil { - s.t.Error("pact verification did not fail") - s.t.Fail() - } + s.require.NotNil(s.pactResult, "pact verification did not fail") return s } func (s *ProxyStage) the_proxy_waits_for_all_requests() *ProxyStage { sent := atomic.LoadInt32(&s.requestsSent) - if sent != s.requestsToSend { - s.t.Errorf("proxy did not wait for requests, sent=%d expected=%d", sent, s.requestsToSend) - s.t.Fail() - } + s.require.Equal(s.requestsToSend, sent, "proxy did not wait for requests") return s } func (s *ProxyStage) the_response_is_(statusCode int) *ProxyStage { s.the_nth_response_is_(1, statusCode) - return s } func (s *ProxyStage) the_response_name_is_(name string) *ProxyStage { s.the_nth_response_name_is_(1, name) - return s } func (s *ProxyStage) the_nth_response_is_(n, statusCode int) *ProxyStage { - if len(s.responses) < n { - s.t.Fatalf("Expected at least %d responses, got %d", n, len(s.responses)) - } - - if s.responses[n-1].StatusCode != statusCode { - s.t.Fatalf("Expected status code on attempt %d: %d, got : %d", n, statusCode, s.responses[n-1].StatusCode) - } - + s.require.GreaterOrEqual(len(s.responses), n, "number of responses is less than expected") + s.require.Equalf(statusCode, s.responses[n-1].StatusCode, "Expected status code on attempt %d: %d, got : %d", n, statusCode, s.responses[n-1].StatusCode) return s } func (s *ProxyStage) the_nth_response_name_is_(n int, name string) *ProxyStage { - if len(s.responses) < n { - s.t.Fatalf("Expected at least %d responses, got %d", n, len(s.responses)) - } + s.require.GreaterOrEqual(len(s.responses), n, "number of responses is less than expected") var body map[string]string - if err := json.Unmarshal(s.responseBodies[n-1], &body); err != nil { - s.t.Fatalf("unable to parse response body, %v", err) - } - - if body["name"] != name { - s.t.Fatalf("Expected name on attempt %d,: %s, got: %s", n, name, body["name"]) - } - + err := json.Unmarshal(s.responseBodies[n-1], &body) + s.require.NoError(err, "unable to parse response body, %v", err) + s.require.Equalf(name, body["name"], "Expected name on attempt %d,: %s, got: %s", n, name, body["name"]) return s } func (s *ProxyStage) the_nth_response_age_is_(n int, age int64) *ProxyStage { - if len(s.responses) < n { - s.t.Fatalf("Expected at least %d responses, got %d", n, len(s.responses)) - } + s.require.GreaterOrEqual(len(s.responses), n, "number of responses is less than expected") var body map[string]int64 - if err := json.Unmarshal(s.responseBodies[n-1], &body); err != nil { - s.t.Fatalf("unable to parse response body, %v", err) - } + err := json.Unmarshal(s.responseBodies[n-1], &body) + s.require.NoError(err, "unable to parse response body") - if body["age"] != age { - s.t.Fatalf("Expected name on attempt %d,: %d, got: %d", n, age, body["age"]) - } + s.require.Equalf(age, body["age"], "Expected name on attempt %d,: %d, got: %d", n, age, body["age"]) return s } func (s *ProxyStage) the_nth_response_body_has_(n int, key, value string) *ProxyStage { - if len(s.responseBodies) < n { - s.t.Fatalf("Expected at least %d responses, got %d", n, len(s.responseBodies)) - } + s.require.GreaterOrEqual(len(s.responseBodies), n, "number of request bodies is les than expected") var responseBody map[string]string - if err := json.Unmarshal(s.responseBodies[n-1], &responseBody); err != nil { - s.t.Fatalf("unable to parse response body, %v", err) - } - - if responseBody[key] != value { - s.t.Fatalf("Expected %s on attempt %d,: %s, got: %s", key, n, value, responseBody[key]) - } - + err := json.Unmarshal(s.responseBodies[n-1], &responseBody) + s.require.NoError(err, "unable to parse response body, %v", err) + s.require.Equalf(value, responseBody[key], "Expected %s on attempt %d,: %s, got: %s", key, n, value, responseBody[key]) return s } @@ -488,46 +439,28 @@ func (s *ProxyStage) the_response_body_to_plain_text_request_is_correct() *Proxy } func (s *ProxyStage) the_nth_response_body_is(n int, data []byte) *ProxyStage { - if len(s.responseBodies) < n { - s.t.Fatalf("Expected at least %d responses, got %d", n, len(s.responseBodies)) - } + s.require.GreaterOrEqual(len(s.responseBodies), n, "number of request bodies is les than expected") body := s.responseBodies[n-1] - if c := bytes.Compare(body, data); c != 0 { - s.t.Fatalf("Expected body did not match. Expected: %s, got: %s", data, body) - } - + c := bytes.Compare(body, data) + s.require.Equal(0, c, "Expected body did not match") return s } func (s *ProxyStage) n_responses_were_received(n int) *ProxyStage { - count := len(s.responses) - if count != n { - s.t.Fatalf("Expected %d responses, got %d", n, count) - } - + s.require.Len(s.responses, n) return s } func (s *ProxyStage) pact_can_be_generated() { u := fmt.Sprintf("http://localhost:%s/pact", proxyURL.Port()) req, err := http.NewRequestWithContext(context.Background(), "POST", u, bytes.NewReader([]byte("{\"pact_specification_version\":\"3.0.0\"}"))) - if err != nil { - s.t.Error(err) - return - } + s.require.NoError(err) req.Header.Add("X-Pact-Mock-Service", "true") req.Header.Add("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) - if err != nil { - s.t.Error(err) - return - } - - if resp.StatusCode != http.StatusOK { - s.t.Fatalf("Expected 200 but returned %d status code", resp.StatusCode) - } - - defer func() { _ = resp.Body.Close() }() + s.require.NoError(err) + defer resp.Body.Close() + s.require.Equal(http.StatusOK, resp.StatusCode, "Expected 200 but returned %d status code", resp.StatusCode) }