diff --git a/binder.go b/binder.go index 3b6bb8bcb..544e90fa1 100644 --- a/binder.go +++ b/binder.go @@ -45,12 +45,12 @@ func (binder Binder) RoundTrip(origReq *http.Request) (*http.Response, error) { req.Proto = fmt.Sprintf("HTTP/%d.%d", req.ProtoMajor, req.ProtoMinor) } - if req.Body != nil { + if req.Body != nil && req.Body != http.NoBody { if req.ContentLength == -1 { req.TransferEncoding = []string{"chunked"} } } else { - req.Body = ioutil.NopCloser(bytes.NewReader(nil)) + req.Body = http.NoBody } if req.URL != nil && req.URL.Scheme == "https" && binder.TLS != nil { diff --git a/body_wrapper.go b/body_wrapper.go new file mode 100644 index 000000000..e3d7e07d6 --- /dev/null +++ b/body_wrapper.go @@ -0,0 +1,170 @@ +package httpexpect + +import ( + "bytes" + "context" + "io" + "io/ioutil" + "runtime" + "sync" +) + +// Wrapper for request or response body reader +// Allows to read body multiple times using two approaches: +// - use Read to read body contents and Rewind to restart reading from beginning +// - use GetBody to get new reader for body contents +type bodyWrapper struct { + currReader io.Reader + + origReader io.ReadCloser + origBytes []byte + origErr error + + cancelFunc context.CancelFunc + + isInitialized bool + + mu sync.Mutex +} + +func newBodyWrapper(reader io.ReadCloser, cancelFunc context.CancelFunc) *bodyWrapper { + bw := &bodyWrapper{ + origReader: reader, + cancelFunc: cancelFunc, + } + + // This is not strictly necessary because we should always call close. + // This is just a reinsurance. + runtime.SetFinalizer(bw, (*bodyWrapper).Close) + + return bw +} + +// Read body contents +func (bw *bodyWrapper) Read(p []byte) (n int, err error) { + bw.mu.Lock() + defer bw.mu.Unlock() + + // Preserve original reader error + if bw.origErr != nil { + return 0, bw.origErr + } + + // Lazy initialization + if !bw.isInitialized { + if err := bw.initialize(); err != nil { + return 0, err + } + } + + if bw.currReader == nil { + bw.currReader = bytes.NewReader(bw.origBytes) + } + return bw.currReader.Read(p) +} + +// Close body +func (bw *bodyWrapper) Close() error { + bw.mu.Lock() + defer bw.mu.Unlock() + + err := bw.origErr + + // Rewind or GetBody may be called later, so be sure to + // read body into memory before closing + if !bw.isInitialized { + initErr := bw.initialize() + if initErr != nil { + err = initErr + } + } + + // Close original reader + closeErr := bw.closeAndCancel() + if closeErr != nil { + err = closeErr + } + + return err +} + +// Rewind reading to the beginning +func (bw *bodyWrapper) Rewind() { + bw.mu.Lock() + defer bw.mu.Unlock() + + // Until first read, rewind is no-op + if !bw.isInitialized { + return + } + + // Reset reader + bw.currReader = bytes.NewReader(bw.origBytes) +} + +// Create new reader to retrieve body contents +// New reader always reads body from the beginning +// Does not affected by Rewind() +func (bw *bodyWrapper) GetBody() (io.ReadCloser, error) { + bw.mu.Lock() + defer bw.mu.Unlock() + + // Preserve original reader error + if bw.origErr != nil { + return nil, bw.origErr + } + + // Lazy initialization + if !bw.isInitialized { + if err := bw.initialize(); err != nil { + return nil, err + } + } + + return ioutil.NopCloser(bytes.NewReader(bw.origBytes)), nil +} + +func (bw *bodyWrapper) initialize() error { + if !bw.isInitialized { + bw.isInitialized = true + + if bw.origReader != nil { + bw.origBytes, bw.origErr = ioutil.ReadAll(bw.origReader) + + _ = bw.closeAndCancel() + + if bw.origErr != nil { + return bw.origErr + } + } + } + + return nil +} + +func (bw *bodyWrapper) closeAndCancel() error { + if bw.origReader == nil && bw.cancelFunc == nil { + return nil + } + + var err error + + if bw.origReader != nil { + err = bw.origReader.Close() + bw.origReader = nil + + if bw.origErr == nil { + bw.origErr = err + } + } + + if bw.cancelFunc != nil { + bw.cancelFunc() + bw.cancelFunc = nil + } + + // Finalizer is not needed anymore. + runtime.SetFinalizer(bw, nil) + + return err +} diff --git a/body_wrapper_test.go b/body_wrapper_test.go new file mode 100644 index 000000000..3b783f107 --- /dev/null +++ b/body_wrapper_test.go @@ -0,0 +1,72 @@ +package httpexpect + +import ( + "io/ioutil" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBodyWrapperRewind(t *testing.T) { + body := newMockBody("test_body") + + cancelled := false + cancelFn := func() { + cancelled = true + } + + wr := newBodyWrapper(body, cancelFn) + + b, err := ioutil.ReadAll(wr) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) + + assert.True(t, body.closed) + assert.True(t, cancelled) + + err = wr.Close() + assert.NoError(t, err) + + wr.Rewind() + + b, err = ioutil.ReadAll(wr) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) +} + +func TestBodyWrapperGetBody(t *testing.T) { + body := newMockBody("test_body") + + wr := newBodyWrapper(body, nil) + + rd1, err := wr.GetBody() + assert.NoError(t, err) + + rd2, err := wr.GetBody() + assert.NoError(t, err) + + b, err := ioutil.ReadAll(rd1) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) + + b, err = ioutil.ReadAll(rd2) + assert.NoError(t, err) + assert.Equal(t, "test_body", string(b)) +} + +func TestBodyWrapperClose(t *testing.T) { + body := newMockBody("test_body") + + cancelled := false + cancelFn := func() { + cancelled = true + } + + wr := newBodyWrapper(body, cancelFn) + + err := wr.Close() + assert.NoError(t, err) + + assert.True(t, body.closed) + assert.True(t, cancelled) +} diff --git a/e2e_printer_test.go b/e2e_printer_test.go new file mode 100644 index 000000000..4cd15e09f --- /dev/null +++ b/e2e_printer_test.go @@ -0,0 +1,82 @@ +package httpexpect + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func createPrinterHandler() http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + body, _ := ioutil.ReadAll(r.Body) + if string(body) != "test_request" { + panic("unexpected request body " + string(body)) + } + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte(`test_response`)) + }) + + return mux +} + +func TestE2EPrinter(t *testing.T) { + handler := createPrinterHandler() + + server := httptest.NewServer(handler) + defer server.Close() + + p := &mockPrinter{} + + e := WithConfig(Config{ + BaseURL: server.URL, + Reporter: NewAssertReporter(t), + Printers: []Printer{ + p, + }, + }) + + e.POST("/test"). + WithText("test_request"). + Expect(). + Text(). + Equal("test_response") + + assert.Equal(t, "test_request", string(p.reqBody)) + assert.Equal(t, "test_response", string(p.respBody)) +} + +func TestE2EPrinterMultiple(t *testing.T) { + handler := createPrinterHandler() + + server := httptest.NewServer(handler) + defer server.Close() + + p1 := &mockPrinter{} + p2 := &mockPrinter{} + + e := WithConfig(Config{ + BaseURL: server.URL, + Reporter: NewAssertReporter(t), + Printers: []Printer{ + p1, + p2, + }, + }) + + e.POST("/test"). + WithText("test_request"). + Expect(). + Text(). + Equal("test_response") + + assert.Equal(t, "test_request", string(p1.reqBody)) + assert.Equal(t, "test_response", string(p1.respBody)) + + assert.Equal(t, "test_request", string(p2.reqBody)) + assert.Equal(t, "test_response", string(p2.respBody)) +} diff --git a/e2e_timeout_test.go b/e2e_timeout_test.go new file mode 100644 index 000000000..489f4d2bf --- /dev/null +++ b/e2e_timeout_test.go @@ -0,0 +1,103 @@ +package httpexpect + +import ( + "math/rand" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randomString(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +func createTimeoutHandler() http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/sleep", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Duration(time.Second)) + }) + + mux.HandleFunc("/small", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`"`)) + _, _ = w.Write([]byte(randomString(10))) + _, _ = w.Write([]byte(`"`)) + }) + + mux.HandleFunc("/large", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`"`)) + _, _ = w.Write([]byte(randomString(1024 * 10))) + _, _ = w.Write([]byte(`"`)) + }) + + return mux +} + +func TestE2ETimeoutDeadlineExpired(t *testing.T) { + handler := createTimeoutHandler() + + server := httptest.NewServer(handler) + defer server.Close() + + r := newMockReporter(t) + + e := WithConfig(Config{ + BaseURL: server.URL, + Reporter: r, + }) + + e.GET("/sleep"). + WithTimeout(10 * time.Millisecond). + Expect() + + assert.True(t, r.reported) +} + +func TestE2ETimeoutSmallBody(t *testing.T) { + handler := createTimeoutHandler() + + server := httptest.NewServer(handler) + defer server.Close() + + e := New(t, server.URL) + + for i := 0; i < 100; i++ { + e.GET("/small"). + WithTimeout(20 * time.Minute). + Expect(). + Status(http.StatusOK). + JSON(). + String() + } +} + +func TestE2ETimeoutLargeBody(t *testing.T) { + handler := createTimeoutHandler() + + server := httptest.NewServer(handler) + defer server.Close() + + e := New(t, server.URL) + + for i := 0; i < 100; i++ { + e.GET("/large"). + WithTimeout(20 * time.Minute). + Expect(). + Status(http.StatusOK). + JSON(). + String() + } +} diff --git a/mocks_test.go b/mocks_test.go index 2dbacad0c..882569d5b 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -1,6 +1,9 @@ package httpexpect import ( + "bytes" + "io" + "io/ioutil" "net/http" "testing" "time" @@ -22,6 +25,23 @@ func (c *mockClient) Do(req *http.Request) (*http.Response, error) { return nil, c.err } +type mockBody struct { + io.Reader + closed bool +} + +func newMockBody(body string) *mockBody { + return &mockBody{ + Reader: bytes.NewBufferString(body), + closed: false, + } +} + +func (b *mockBody) Close() error { + b.closed = true + return nil +} + type mockReporter struct { testing *testing.T reported bool @@ -36,6 +56,27 @@ func (r *mockReporter) Errorf(message string, args ...interface{}) { r.reported = true } +type mockPrinter struct { + reqBody []byte + respBody []byte + rtt time.Duration +} + +func (p *mockPrinter) Request(req *http.Request) { + if req.Body != nil { + p.reqBody, _ = ioutil.ReadAll(req.Body) + req.Body.Close() + } +} + +func (p *mockPrinter) Response(resp *http.Response, rtt time.Duration) { + if resp.Body != nil { + p.respBody, _ = ioutil.ReadAll(resp.Body) + resp.Body.Close() + } + p.rtt = rtt +} + type mockWebsocketConn struct { msgType int readMsgErr error diff --git a/request.go b/request.go index e7ffea834..3329f95fe 100644 --- a/request.go +++ b/request.go @@ -1306,58 +1306,75 @@ func (r *Request) sendWebsocketRequest() ( return resp, conn, elapsed } -func (r *Request) retryRequest(reqFunc func() (resp *http.Response, err error)) ( - resp *http.Response, elapsed time.Duration, err error, +func (r *Request) retryRequest(reqFunc func() (*http.Response, error)) ( + *http.Response, time.Duration, error, ) { - var body []byte - if r.maxRetries > 0 && r.http.Body != nil && r.http.Body != http.NoBody { - body, _ = ioutil.ReadAll(r.http.Body) + if r.http.Body != nil && r.http.Body != http.NoBody { + if _, ok := r.http.Body.(*bodyWrapper); !ok { + r.http.Body = newBodyWrapper(r.http.Body, nil) + } } + reqBody, _ := r.http.Body.(*bodyWrapper) + delay := r.minRetryDelay i := 0 for { - if body != nil { - r.http.Body = ioutil.NopCloser(bytes.NewReader(body)) - } - for _, printer := range r.config.Printers { + if reqBody != nil { + reqBody.Rewind() + } printer.Request(r.http) } - func() { - if r.timeout > 0 { - var ctx context.Context - var cancel context.CancelFunc - if r.config.Context != nil { - ctx, cancel = context.WithTimeout(r.config.Context, r.timeout) - } else { - ctx, cancel = context.WithTimeout(context.Background(), r.timeout) - } + if reqBody != nil { + reqBody.Rewind() + } + + var cancelFn context.CancelFunc - defer cancel() - r.http = r.http.WithContext(ctx) + if r.timeout > 0 { + var ctx context.Context + if r.config.Context != nil { + ctx, cancelFn = context.WithTimeout(r.config.Context, r.timeout) + } else { + ctx, cancelFn = context.WithTimeout(context.Background(), r.timeout) } - start := time.Now() - resp, err = reqFunc() - elapsed = time.Since(start) - }() + r.http = r.http.WithContext(ctx) + } + + start := time.Now() + resp, err := reqFunc() + elapsed := time.Since(start) + + if resp != nil && resp.Body != nil { + resp.Body = newBodyWrapper(resp.Body, cancelFn) + } else if cancelFn != nil { + cancelFn() + } if resp != nil { for _, printer := range r.config.Printers { + if resp.Body != nil { + resp.Body.(*bodyWrapper).Rewind() + } printer.Response(resp, elapsed) } } i++ if i == r.maxRetries+1 { - return + return resp, elapsed, err } if !r.shouldRetry(resp, err) { - return + return resp, elapsed, err + } + + if resp != nil && resp.Body != nil { + resp.Body.Close() } time.Sleep(delay) @@ -1483,19 +1500,17 @@ func (r *Request) setupRedirects() { if r.redirectPolicy == FollowAllRedirects { if r.http.Body != nil && r.http.Body != http.NoBody { - bodyBytes, bodyErr := ioutil.ReadAll(r.http.Body) - - r.http.GetBody = func() (io.ReadCloser, error) { - if bodyErr != nil { - return nil, bodyErr - } - return ioutil.NopCloser(bytes.NewReader(bodyBytes)), nil + if _, ok := r.http.Body.(*bodyWrapper); !ok { + r.http.Body = newBodyWrapper(r.http.Body, nil) } + r.http.GetBody = r.http.Body.(*bodyWrapper).GetBody } else { r.http.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil } } + } else if r.redirectPolicy != defaultRedirectPolicy { + r.http.GetBody = nil } } diff --git a/response.go b/response.go index 984992ba5..066347b21 100644 --- a/response.go +++ b/response.go @@ -102,7 +102,17 @@ func getContent(chain *chain, resp *http.Response) []byte { return []byte{} } + if bw, ok := resp.Body.(*bodyWrapper); ok { + bw.Rewind() + } + content, err := ioutil.ReadAll(resp.Body) + + closeErr := resp.Body.Close() + if err == nil { + err = closeErr + } + if err != nil { chain.fail(err.Error()) return nil diff --git a/response_test.go b/response_test.go index c960f1184..9f159055a 100644 --- a/response_test.go +++ b/response_test.go @@ -1065,3 +1065,21 @@ func TestResponseContentOpts(t *testing.T) { }) }) } + +func TestResponseBodyClosing(t *testing.T) { + reporter := newMockReporter(t) + + body := newMockBody("test_body") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: body, + } + + resp := NewResponse(reporter, httpResp) + + assert.Equal(t, "test_body", resp.Body().Raw()) + assert.True(t, body.closed) + + resp.chain.assertOK(t) +}