Skip to content

Commit

Permalink
Fix problems with body reading and closing
Browse files Browse the repository at this point in the history
  • Loading branch information
gavv committed Nov 14, 2022
1 parent 47a669e commit 18182d4
Show file tree
Hide file tree
Showing 9 changed files with 546 additions and 35 deletions.
4 changes: 2 additions & 2 deletions binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ func (binder Binder) RoundTrip(origReq *http.Request) (*http.Response, error) {
req.Proto = fmt.Sprintf("HTTP/%d.%d", req.ProtoMajor, req.ProtoMinor)
}

if req.Body != nil {
if req.Body != nil && req.Body != http.NoBody {
if req.ContentLength == -1 {
req.TransferEncoding = []string{"chunked"}
}
} else {
req.Body = ioutil.NopCloser(bytes.NewReader(nil))
req.Body = http.NoBody
}

if req.URL != nil && req.URL.Scheme == "https" && binder.TLS != nil {
Expand Down
170 changes: 170 additions & 0 deletions body_wrapper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package httpexpect

import (
"bytes"
"context"
"io"
"io/ioutil"
"runtime"
"sync"
)

// Wrapper for request or response body reader
// Allows to read body multiple times using two approaches:
// - use Read to read body contents and Rewind to restart reading from beginning
// - use GetBody to get new reader for body contents
type bodyWrapper struct {
currReader io.Reader

origReader io.ReadCloser
origBytes []byte
origErr error

cancelFunc context.CancelFunc

isInitialized bool

mu sync.Mutex
}

func newBodyWrapper(reader io.ReadCloser, cancelFunc context.CancelFunc) *bodyWrapper {
bw := &bodyWrapper{
origReader: reader,
cancelFunc: cancelFunc,
}

// This is not strictly necessary because we should always call close.
// This is just a reinsurance.
runtime.SetFinalizer(bw, (*bodyWrapper).Close)

return bw
}

// Read body contents
func (bw *bodyWrapper) Read(p []byte) (n int, err error) {
bw.mu.Lock()
defer bw.mu.Unlock()

// Preserve original reader error
if bw.origErr != nil {
return 0, bw.origErr
}

// Lazy initialization
if !bw.isInitialized {
if err := bw.initialize(); err != nil {
return 0, err
}
}

if bw.currReader == nil {
bw.currReader = bytes.NewReader(bw.origBytes)
}
return bw.currReader.Read(p)
}

// Close body
func (bw *bodyWrapper) Close() error {
bw.mu.Lock()
defer bw.mu.Unlock()

err := bw.origErr

// Rewind or GetBody may be called later, so be sure to
// read body into memory before closing
if !bw.isInitialized {
initErr := bw.initialize()
if initErr != nil {
err = initErr
}
}

// Close original reader
closeErr := bw.closeAndCancel()
if closeErr != nil {
err = closeErr
}

return err
}

// Rewind reading to the beginning
func (bw *bodyWrapper) Rewind() {
bw.mu.Lock()
defer bw.mu.Unlock()

// Until first read, rewind is no-op
if !bw.isInitialized {
return
}

// Reset reader
bw.currReader = bytes.NewReader(bw.origBytes)
}

// Create new reader to retrieve body contents
// New reader always reads body from the beginning
// Does not affected by Rewind()
func (bw *bodyWrapper) GetBody() (io.ReadCloser, error) {
bw.mu.Lock()
defer bw.mu.Unlock()

// Preserve original reader error
if bw.origErr != nil {
return nil, bw.origErr
}

// Lazy initialization
if !bw.isInitialized {
if err := bw.initialize(); err != nil {
return nil, err
}
}

return ioutil.NopCloser(bytes.NewReader(bw.origBytes)), nil
}

func (bw *bodyWrapper) initialize() error {
if !bw.isInitialized {
bw.isInitialized = true

if bw.origReader != nil {
bw.origBytes, bw.origErr = ioutil.ReadAll(bw.origReader)

_ = bw.closeAndCancel()

if bw.origErr != nil {
return bw.origErr
}
}
}

return nil
}

func (bw *bodyWrapper) closeAndCancel() error {
if bw.origReader == nil && bw.cancelFunc == nil {
return nil
}

var err error

if bw.origReader != nil {
err = bw.origReader.Close()
bw.origReader = nil

if bw.origErr == nil {
bw.origErr = err
}
}

if bw.cancelFunc != nil {
bw.cancelFunc()
bw.cancelFunc = nil
}

// Finalizer is not needed anymore.
runtime.SetFinalizer(bw, nil)

return err
}
72 changes: 72 additions & 0 deletions body_wrapper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package httpexpect

import (
"io/ioutil"
"testing"

"github.com/stretchr/testify/assert"
)

func TestBodyWrapperRewind(t *testing.T) {
body := newMockBody("test_body")

cancelled := false
cancelFn := func() {
cancelled = true
}

wr := newBodyWrapper(body, cancelFn)

b, err := ioutil.ReadAll(wr)
assert.NoError(t, err)
assert.Equal(t, "test_body", string(b))

assert.True(t, body.closed)
assert.True(t, cancelled)

err = wr.Close()
assert.NoError(t, err)

wr.Rewind()

b, err = ioutil.ReadAll(wr)
assert.NoError(t, err)
assert.Equal(t, "test_body", string(b))
}

func TestBodyWrapperGetBody(t *testing.T) {
body := newMockBody("test_body")

wr := newBodyWrapper(body, nil)

rd1, err := wr.GetBody()
assert.NoError(t, err)

rd2, err := wr.GetBody()
assert.NoError(t, err)

b, err := ioutil.ReadAll(rd1)
assert.NoError(t, err)
assert.Equal(t, "test_body", string(b))

b, err = ioutil.ReadAll(rd2)
assert.NoError(t, err)
assert.Equal(t, "test_body", string(b))
}

func TestBodyWrapperClose(t *testing.T) {
body := newMockBody("test_body")

cancelled := false
cancelFn := func() {
cancelled = true
}

wr := newBodyWrapper(body, cancelFn)

err := wr.Close()
assert.NoError(t, err)

assert.True(t, body.closed)
assert.True(t, cancelled)
}
82 changes: 82 additions & 0 deletions e2e_printer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package httpexpect

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func createPrinterHandler() http.Handler {
mux := http.NewServeMux()

mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
body, _ := ioutil.ReadAll(r.Body)
if string(body) != "test_request" {
panic("unexpected request body " + string(body))
}
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(`test_response`))
})

return mux
}

func TestE2EPrinter(t *testing.T) {
handler := createPrinterHandler()

server := httptest.NewServer(handler)
defer server.Close()

p := &mockPrinter{}

e := WithConfig(Config{
BaseURL: server.URL,
Reporter: NewAssertReporter(t),
Printers: []Printer{
p,
},
})

e.POST("/test").
WithText("test_request").
Expect().
Text().
Equal("test_response")

assert.Equal(t, "test_request", string(p.reqBody))
assert.Equal(t, "test_response", string(p.respBody))
}

func TestE2EPrinterMultiple(t *testing.T) {
handler := createPrinterHandler()

server := httptest.NewServer(handler)
defer server.Close()

p1 := &mockPrinter{}
p2 := &mockPrinter{}

e := WithConfig(Config{
BaseURL: server.URL,
Reporter: NewAssertReporter(t),
Printers: []Printer{
p1,
p2,
},
})

e.POST("/test").
WithText("test_request").
Expect().
Text().
Equal("test_response")

assert.Equal(t, "test_request", string(p1.reqBody))
assert.Equal(t, "test_response", string(p1.respBody))

assert.Equal(t, "test_request", string(p2.reqBody))
assert.Equal(t, "test_response", string(p2.respBody))
}
Loading

0 comments on commit 18182d4

Please sign in to comment.