diff --git a/file.go b/file.go index 9224378..e738b2e 100644 --- a/file.go +++ b/file.go @@ -2,9 +2,12 @@ package httptest import ( "bytes" + "fmt" "io" "mime/multipart" "net/http" + "net/textproto" + "strings" ) type File struct { @@ -29,19 +32,35 @@ func (r *Request) MultiPartPut(body interface{}, files ...File) (*Response, erro return r.Perform(req), nil } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + // this helper method was inspired by this blog post by Matt Aimonetti: // https://matt.aimonetti.net/posts/2013/07/01/golang-multipart-file-upload-example/ func newMultipart(url string, method string, body interface{}, files ...File) (*http.Request, error) { - bb := &bytes.Buffer{} writer := multipart.NewWriter(bb) defer writer.Close() for _, f := range files { - part, err := writer.CreateFormFile(f.ParamName, f.FileName) + fBuffer, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", + fmt.Sprintf(`form-data; name="%s"; filename="%s"`, + escapeQuotes(f.ParamName), escapeQuotes(f.FileName))) + h.Set("Content-Type", http.DetectContentType(fBuffer)) + part, err := writer.CreatePart(h) if err != nil { return nil, err } - _, err = io.Copy(part, f) + fReader := bytes.NewReader(fBuffer) + _, err = io.Copy(part, fReader) if err != nil { return nil, err } diff --git a/file_test.go b/file_test.go index 84e9a2c..4eb4655 100644 --- a/file_test.go +++ b/file_test.go @@ -1,7 +1,8 @@ package httptest import ( - "os" + "fmt" + "github.com/gobuffalo/httptest/testassets" "testing" "github.com/stretchr/testify/require" @@ -11,20 +12,27 @@ func Test_FileUpload(t *testing.T) { r := require.New(t) w := New(App()) - f := struct { - Name string - }{"Foo"} + foo := func(filename, expectedType string) { + f := struct { + Name string + }{"Foo"} - rr, err := os.Open("./file_test.go") - r.NoError(err) - wf := File{ - ParamName: "MyFile", - FileName: "foo.go", - Reader: rr, + rr, err := testassets.FS.Open(filename) + r.NoError(err) + wf := File{ + ParamName: "MyFile", + FileName: filename, + Reader: rr, + } + res, err := w.HTML("/up").MultiPartPost(f, wf) + r.NoError(err) + r.Equal(200, res.Code) + r.Equal(fmt.Sprintf("Foo\n%s\n%s\n", filename, expectedType), res.Body.String()) } - res, err := w.HTML("/up").MultiPartPost(f, wf) - r.NoError(err) - r.Equal(200, res.Code) - r.Equal("Foo\nfoo.go\n", res.Body.String()) + foo("test.jpg", "image/jpeg") + foo("test.png", "image/png") + foo("test.pdf", "application/pdf") + foo("embed.go", "text/plain; charset=utf-8") + foo("random.bin", "application/octet-stream") } diff --git a/httptest_test.go b/httptest_test.go index c6b6bef..44f9e6a 100644 --- a/httptest_test.go +++ b/httptest_test.go @@ -95,6 +95,7 @@ func App() http.Handler { } fmt.Fprintln(res, req.FormValue("Name")) fmt.Fprintln(res, h.Filename) + fmt.Fprintln(res, h.Header.Get("Content-Type")) }) return p } diff --git a/testassets/embed.go b/testassets/embed.go new file mode 100644 index 0000000..5672ed0 --- /dev/null +++ b/testassets/embed.go @@ -0,0 +1,8 @@ +package testassets + +import ( + "embed" +) + +//go:embed * +var FS embed.FS diff --git a/testassets/random.bin b/testassets/random.bin new file mode 100644 index 0000000..8cfe9fb --- /dev/null +++ b/testassets/random.bin @@ -0,0 +1 @@ +># n-p8 \ No newline at end of file diff --git a/testassets/test.jpg b/testassets/test.jpg new file mode 100644 index 0000000..a9499c4 Binary files /dev/null and b/testassets/test.jpg differ diff --git a/testassets/test.pdf b/testassets/test.pdf new file mode 100644 index 0000000..dfcb106 Binary files /dev/null and b/testassets/test.pdf differ diff --git a/testassets/test.png b/testassets/test.png new file mode 100644 index 0000000..4e29bfd Binary files /dev/null and b/testassets/test.png differ