diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ebaab1ac..db527b51 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,6 +3,7 @@ name: CI on: push: branches: + - v3 - v2 paths-ignore: - '**.md' @@ -11,6 +12,7 @@ on: pull_request: branches: - main + - v3 - v2 paths-ignore: - '**.md' @@ -47,7 +49,7 @@ jobs: run: diff -u <(echo -n) <(go fmt $(go list ./...)) - name: Test - run: go test ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... -shuffle=on + run: go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... -shuffle=on - name: Upload coverage to Codecov if: ${{ matrix.os == 'ubuntu-latest' && matrix.go == 'stable' }} diff --git a/.github/workflows/label-actions.yml b/.github/workflows/label-actions.yml index 8fee6ffa..515b6f7b 100644 --- a/.github/workflows/label-actions.yml +++ b/.github/workflows/label-actions.yml @@ -36,7 +36,7 @@ jobs: run: diff -u <(echo -n) <(go fmt $(go list ./...)) - name: Test - run: go test ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... -shuffle=on + run: go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... -shuffle=on - name: Upload coverage to Codecov if: ${{ matrix.os == 'ubuntu-latest' && matrix.go == 'stable' }} diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 00000000..d90a54cc --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,205 @@ +// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. +// resty source code and usage is governed by a MIT style +// license that can be found in the LICENSE file. +// SPDX-License-Identifier: MIT + +package resty + +import ( + "bytes" + "strings" + "testing" +) + +func Benchmark_parseRequestURL_PathParams(b *testing.B) { + c := New().SetPathParams(map[string]string{ + "foo": "1", + "bar": "2", + }).SetRawPathParams(map[string]string{ + "foo": "3", + "xyz": "4", + }) + r := c.R().SetPathParams(map[string]string{ + "foo": "5", + "qwe": "6", + }).SetRawPathParams(map[string]string{ + "foo": "7", + "asd": "8", + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.URL = "https://example.com/{foo}/{bar}/{xyz}/{qwe}/{asd}" + if err := parseRequestURL(c, r); err != nil { + b.Errorf("parseRequestURL() error = %v", err) + } + } +} + +func Benchmark_parseRequestURL_QueryParams(b *testing.B) { + c := New().SetQueryParams(map[string]string{ + "foo": "1", + "bar": "2", + }) + r := c.R().SetQueryParams(map[string]string{ + "foo": "5", + "qwe": "6", + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.URL = "https://example.com/" + if err := parseRequestURL(c, r); err != nil { + b.Errorf("parseRequestURL() error = %v", err) + } + } +} + +func Benchmark_parseRequestHeader(b *testing.B) { + c := New() + r := c.R() + c.SetHeaders(map[string]string{ + "foo": "1", // ignored, because of the same header in the request + "bar": "2", + }) + r.SetHeaders(map[string]string{ + "foo": "3", + "xyz": "4", + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestHeader(c, r); err != nil { + b.Errorf("parseRequestHeader() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_string(b *testing.B) { + c := New() + r := c.R() + r.SetBody("foo").SetContentLength(true) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_byte(b *testing.B) { + c := New() + r := c.R() + r.SetBody([]byte("foo")).SetContentLength(true) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_reader(b *testing.B) { + c := New() + r := c.R() + r.SetBody(bytes.NewBufferString("foo")) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_struct(b *testing.B) { + type FooBar struct { + Foo string `json:"foo"` + Bar string `json:"bar"` + } + c := New() + r := c.R() + r.SetBody(FooBar{Foo: "1", Bar: "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_struct_xml(b *testing.B) { + type FooBar struct { + Foo string `xml:"foo"` + Bar string `xml:"bar"` + } + c := New() + r := c.R() + r.SetBody(FooBar{Foo: "1", Bar: "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, "text/xml") + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_map(b *testing.B) { + c := New() + r := c.R() + r.SetBody(map[string]string{ + "foo": "1", + "bar": "2", + }).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_slice(b *testing.B) { + c := New() + r := c.R() + r.SetBody([]string{"1", "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_FormData(b *testing.B) { + c := New() + r := c.R() + c.SetFormData(map[string]string{"foo": "1", "bar": "2"}) + r.SetFormData(map[string]string{"foo": "3", "baz": "4"}).SetContentLength(true) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} + +func Benchmark_parseRequestBody_MultiPart(b *testing.B) { + c := New() + r := c.R() + c.SetFormData(map[string]string{"foo": "1", "bar": "2"}) + r.SetFormData(map[string]string{"foo": "3", "baz": "4"}). + SetMultipartFormData(map[string]string{"foo": "5", "xyz": "6"}). + SetFileReader("qwe", "qwe.txt", strings.NewReader("7")). + SetMultipartFields( + &MultipartField{ + Name: "sdj", + ContentType: "text/plain", + Reader: strings.NewReader("8"), + }, + ). + SetContentLength(true). + SetMethod(MethodPost) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := parseRequestBody(c, r); err != nil { + b.Errorf("parseRequestBody() error = %v", err) + } + } +} diff --git a/client.go b/client.go index fd174ead..45d2ef2e 100644 --- a/client.go +++ b/client.go @@ -58,7 +58,6 @@ var ( ErrUnsupportedRequestBodyKind = errors.New("resty: unsupported request body kind") hdrUserAgentKey = http.CanonicalHeaderKey("User-Agent") - hdrAcceptKey = http.CanonicalHeaderKey("Accept") hdrAcceptEncodingKey = http.CanonicalHeaderKey("Accept-Encoding") hdrContentTypeKey = http.CanonicalHeaderKey("Content-Type") hdrContentLengthKey = http.CanonicalHeaderKey("Content-Length") @@ -172,7 +171,6 @@ type Client struct { queryParams url.Values formData url.Values pathParams map[string]string - rawPathParams map[string]string header http.Header credentials *credentials authToken string @@ -220,8 +218,8 @@ type Client struct { successHooks []SuccessHook contentTypeEncoders map[string]ContentTypeEncoder contentTypeDecoders map[string]ContentTypeDecoder - contentDecompressorKeys []string - contentDecompressors map[string]ContentDecompressor + contentDecompresserKeys []string + contentDecompressers map[string]ContentDecompresser certWatcherStopChan chan bool } @@ -258,7 +256,7 @@ func (c *Client) SetBaseURL(url string) *Client { return c } -// LoadBalancer method returns the requestload balancer instance from the client +// LoadBalancer method returns the request load balancer instance from the client // instance. Otherwise returns nil. func (c *Client) LoadBalancer() LoadBalancer { c.lock.RLock() @@ -281,17 +279,17 @@ func (c *Client) Header() http.Header { return c.header } -// SetHeader method sets a single header field and its value in the client instance. -// These headers will be applied to all requests from this client instance. +// SetHeader method sets a single header and its value in the client instance. +// These headers will be applied to all requests raised from the client instance. // Also, it can be overridden by request-level header options. // -// See [Request.SetHeader] or [Request.SetHeaders]. -// // For Example: To set `Content-Type` and `Accept` as `application/json` // // client. // SetHeader("Content-Type", "application/json"). // SetHeader("Accept", "application/json") +// +// See [Request.SetHeader] or [Request.SetHeaders]. func (c *Client) SetHeader(header, value string) *Client { c.lock.Lock() defer c.lock.Unlock() @@ -299,11 +297,9 @@ func (c *Client) SetHeader(header, value string) *Client { return c } -// SetHeaders method sets multiple header fields and their values at one go in the client instance. -// These headers will be applied to all requests from this client instance. Also, it can be -// overridden at request level headers options. -// -// See [Request.SetHeaders] or [Request.SetHeader]. +// SetHeaders method sets multiple headers and their values at one go, and +// these headers will be applied to all requests raised from the client instance. +// Also, it can be overridden at request-level headers options. // // For Example: To set `Content-Type` and `Accept` as `application/json` // @@ -311,6 +307,8 @@ func (c *Client) SetHeader(header, value string) *Client { // "Content-Type": "application/json", // "Accept": "application/json", // }) +// +// See [Request.SetHeaders] or [Request.SetHeader]. func (c *Client) SetHeaders(headers map[string]string) *Client { c.lock.Lock() defer c.lock.Unlock() @@ -329,6 +327,8 @@ func (c *Client) SetHeaders(headers map[string]string) *Client { // SetHeaderVerbatim("all_lowercase", "available"). // SetHeaderVerbatim("UPPERCASE", "available"). // SetHeaderVerbatim("x-cloud-trace-id", "798e94019e5fc4d57fbb8901eb4c6cae") +// +// See [Request.SetHeaderVerbatim]. func (c *Client) SetHeaderVerbatim(header, value string) *Client { c.lock.Lock() defer c.lock.Unlock() @@ -612,7 +612,6 @@ func (c *Client) R() *Request { Header: http.Header{}, Cookies: make([]*http.Cookie, 0), PathParams: make(map[string]string), - RawPathParams: make(map[string]string), Timeout: c.timeout, Debug: c.debug, IsTrace: c.isTrace, @@ -885,49 +884,49 @@ func (c *Client) inferContentTypeDecoder(ct ...string) (ContentTypeDecoder, bool return nil, false } -// ContentDecompressors method returns all the registered content-encoding decompressors. -func (c *Client) ContentDecompressors() map[string]ContentDecompressor { +// ContentDecompressers method returns all the registered content-encoding Decompressers. +func (c *Client) ContentDecompressers() map[string]ContentDecompresser { c.lock.RLock() defer c.lock.RUnlock() - return c.contentDecompressors + return c.contentDecompressers } -// AddContentDecompressor method adds the user-provided Content-Encoding ([RFC 9110]) decompressor +// AddContentDecompresser method adds the user-provided Content-Encoding ([RFC 9110]) Decompresser // and directive into a client. // -// NOTE: It overwrites the decompressor function if the given Content-Encoding directive already exists. +// NOTE: It overwrites the Decompresser function if the given Content-Encoding directive already exists. // // [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110 -func (c *Client) AddContentDecompressor(k string, d ContentDecompressor) *Client { +func (c *Client) AddContentDecompresser(k string, d ContentDecompresser) *Client { c.lock.Lock() defer c.lock.Unlock() - if !slices.Contains(c.contentDecompressorKeys, k) { - c.contentDecompressorKeys = slices.Insert(c.contentDecompressorKeys, 0, k) + if !slices.Contains(c.contentDecompresserKeys, k) { + c.contentDecompresserKeys = slices.Insert(c.contentDecompresserKeys, 0, k) } - c.contentDecompressors[k] = d + c.contentDecompressers[k] = d return c } -// ContentDecompressorKeys method returns all the registered content-encoding decompressors +// ContentDecompresserKeys method returns all the registered content-encoding Decompressers // keys as comma-separated string. -func (c *Client) ContentDecompressorKeys() string { +func (c *Client) ContentDecompresserKeys() string { c.lock.RLock() defer c.lock.RUnlock() - return strings.Join(c.contentDecompressorKeys, ", ") + return strings.Join(c.contentDecompresserKeys, ", ") } -// SetContentDecompressorKeys method sets given Content-Encoding ([RFC 9110]) directives into the client instance. +// SetContentDecompresserKeys method sets given Content-Encoding ([RFC 9110]) directives into the client instance. // -// It checks the given Content-Encoding exists in the [ContentDecompressor] list before assigning it, +// It checks the given Content-Encoding exists in the [ContentDecompresser] list before assigning it, // if it does not exist, it will skip that directive. // -// Use this method to overwrite the default order. If a new content decompressor is added, +// Use this method to overwrite the default order. If a new content Decompresser is added, // that directive will be the first. // // [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110 -func (c *Client) SetContentDecompressorKeys(keys []string) *Client { +func (c *Client) SetContentDecompresserKeys(keys []string) *Client { result := make([]string, 0) - decoders := c.ContentDecompressors() + decoders := c.ContentDecompressers() for _, k := range keys { if _, f := decoders[k]; f { result = append(result, k) @@ -936,7 +935,7 @@ func (c *Client) SetContentDecompressorKeys(keys []string) *Client { c.lock.Lock() defer c.lock.Unlock() - c.contentDecompressorKeys = result + c.contentDecompresserKeys = result return c } @@ -1379,9 +1378,7 @@ func (c *Client) AddRetryHook(hook RetryHookFunc) *Client { func (c *Client) TLSClientConfig() *tls.Config { cfg, err := c.tlsConfig() if err != nil { - c.lock.RLock() - c.log.Errorf("%v", err) - c.lock.RUnlock() + c.Logger().Errorf("%v", err) } return cfg } @@ -1435,20 +1432,20 @@ func (c *Client) ProxyURL() *url.URL { func (c *Client) SetProxy(proxyURL string) *Client { transport, err := c.HTTPTransport() if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) return c } pURL, err := url.Parse(proxyURL) if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) return c } c.lock.Lock() c.proxyURL = pURL + transport.Proxy = http.ProxyURL(c.proxyURL) c.lock.Unlock() - transport.Proxy = http.ProxyURL(c.ProxyURL()) return c } @@ -1458,7 +1455,7 @@ func (c *Client) SetProxy(proxyURL string) *Client { func (c *Client) RemoveProxy() *Client { transport, err := c.HTTPTransport() if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) return c } @@ -1473,7 +1470,7 @@ func (c *Client) RemoveProxy() *Client { func (c *Client) SetCertificates(certs ...tls.Certificate) *Client { config, err := c.tlsConfig() if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) return c } @@ -1489,7 +1486,7 @@ func (c *Client) SetCertificates(certs ...tls.Certificate) *Client { func (c *Client) SetRootCertificate(pemFilePath string) *Client { rootPemData, err := os.ReadFile(pemFilePath) if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) return c } c.handleCAs("root", rootPemData) @@ -1527,7 +1524,7 @@ func (c *Client) SetRootCertificateFromString(pemCerts string) *Client { func (c *Client) SetClientRootCertificate(pemFilePath string) *Client { rootPemData, err := os.ReadFile(pemFilePath) if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) return c } c.handleCAs("client", rootPemData) @@ -1559,7 +1556,7 @@ func (c *Client) initCertWatcher(pemFilePath, scope string, options *CertWatcher ticker := time.NewTicker(tickerDuration) st, err := os.Stat(pemFilePath) if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) return } @@ -1576,7 +1573,7 @@ func (c *Client) initCertWatcher(pemFilePath, scope string, options *CertWatcher st, err = os.Stat(pemFilePath) if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) continue } newModTime := st.ModTime().UTC() @@ -1615,7 +1612,7 @@ func (c *Client) SetClientRootCertificateFromString(pemCerts string) *Client { func (c *Client) handleCAs(scope string, permCerts []byte) { config, err := c.tlsConfig() if err != nil { - c.log.Errorf("%v", err) + c.Logger().Errorf("%v", err) return } @@ -1770,7 +1767,7 @@ func (c *Client) PathParams() map[string]string { func (c *Client) SetPathParam(param, value string) *Client { c.lock.Lock() defer c.lock.Unlock() - c.pathParams[param] = value + c.pathParams[param] = url.PathEscape(value) return c } @@ -1799,44 +1796,31 @@ func (c *Client) SetPathParams(params map[string]string) *Client { return c } -// RawPathParams method returns the raw path parameters from the client. -func (c *Client) RawPathParams() map[string]string { - c.lock.RLock() - defer c.lock.RUnlock() - return c.rawPathParams -} - // SetRawPathParam method sets a single URL path key-value pair in the -// Resty client instance. -// -// client.SetPathParam("userId", "sample@sample.com") -// -// Result: -// URL - /v1/users/{userId}/details -// Composed URL - /v1/users/sample@sample.com/details +// Resty client instance without path escape. // -// client.SetPathParam("path", "groups/developers") +// client.SetRawPathParam("path", "groups/developers") // // Result: -// URL - /v1/users/{userId}/details -// Composed URL - /v1/users/groups%2Fdevelopers/details +// URL - /v1/users/{userId}/details +// Composed URL - /v1/users/groups/developers/details // // It replaces the value of the key while composing the request URL. -// The value will be used as it is and will not be escaped. +// The value will be used as-is, no path escape applied. // // It can be overridden at the request level, // see [Request.SetRawPathParam] or [Request.SetRawPathParams] func (c *Client) SetRawPathParam(param, value string) *Client { c.lock.Lock() defer c.lock.Unlock() - c.rawPathParams[param] = value + c.pathParams[param] = value return c } // SetRawPathParams method sets multiple URL path key-value pairs at one go in the -// Resty client instance. +// Resty client instance without path escape. // -// client.SetPathParams(map[string]string{ +// client.SetRawPathParams(map[string]string{ // "userId": "sample@sample.com", // "subAccountId": "100002", // "path": "groups/developers", @@ -1847,7 +1831,7 @@ func (c *Client) SetRawPathParam(param, value string) *Client { // Composed URL - /v1/users/sample@sample.com/100002/groups/developers/details // // It replaces the value of the key while composing the request URL. -// The values will be used as they are and will not be escaped. +// The value will be used as-is, no path escape applied. // // It can be overridden at the request level, // see [Request.SetRawPathParam] or [Request.SetRawPathParams] @@ -2035,12 +2019,11 @@ func (c *Client) Clone(ctx context.Context) *Client { cc.formData = cloneURLValues(c.formData) cc.header = c.header.Clone() cc.pathParams = maps.Clone(c.pathParams) - cc.rawPathParams = maps.Clone(c.rawPathParams) cc.credentials = c.credentials.Clone() cc.contentTypeEncoders = maps.Clone(c.contentTypeEncoders) cc.contentTypeDecoders = maps.Clone(c.contentTypeDecoders) - cc.contentDecompressors = maps.Clone(c.contentDecompressors) - copy(cc.contentDecompressorKeys, c.contentDecompressorKeys) + cc.contentDecompressers = maps.Clone(c.contentDecompressers) + copy(cc.contentDecompresserKeys, c.contentDecompresserKeys) if c.proxyURL != nil { cc.proxyURL, _ = url.Parse(c.proxyURL.String()) @@ -2104,7 +2087,7 @@ func (c *Client) execute(req *Request) (*Response, error) { } if resp != nil { response.Body = resp.Body - if err = response.wrapContentDecompressor(); err != nil { + if err = response.wrapContentDecompresser(); err != nil { return response, err } @@ -2153,7 +2136,7 @@ func (c *Client) tlsConfig() (*tls.Config, error) { // just an internal helper method func (c *Client) outputLogTo(w io.Writer) *Client { - c.log.(*logger).l.SetOutput(w) + c.Logger().(*logger).l.SetOutput(w) return c } @@ -2210,12 +2193,8 @@ func (c *Client) onInvalidHooks(req *Request, err error) { } } -func (c *Client) debugf(format string, v ...interface{}) { - c.lock.RLock() - defer c.lock.RUnlock() - if !c.debug { - return +func (c *Client) debugf(format string, v ...any) { + if c.IsDebug() { + c.Logger().Debugf(format, v...) } - - c.log.Debugf(format, v...) } diff --git a/client_test.go b/client_test.go index 7d34ee53..42db25d4 100644 --- a/client_test.go +++ b/client_test.go @@ -281,6 +281,9 @@ type CustomRoundTripper2 struct { // RoundTrip just for test func (rt *CustomRoundTripper2) RoundTrip(_ *http.Request) (*http.Response, error) { + if rt.returnErr { + return nil, errors.New("test req mock error") + } return &http.Response{}, nil } @@ -541,18 +544,18 @@ func TestClientPreRequestMiddlewares(t *testing.T) { client := dcnl() fnPreRequestMiddleware1 := func(c *Client, r *Request) error { - c.log.Debugf("I'm in Pre-Request Hook") + c.Logger().Debugf("I'm in Pre-Request Hook") return nil } fnPreRequestMiddleware2 := func(c *Client, r *Request) error { - c.log.Debugf("I'm Overwriting existing Pre-Request Hook") + c.Logger().Debugf("I'm Overwriting existing Pre-Request Hook") // Reading Request `N` no of times for i := 0; i < 5; i++ { b, _ := r.RawRequest.GetBody() rb, _ := io.ReadAll(b) - c.log.Debugf("%s %v", string(rb), len(rb)) + c.Logger().Debugf("%s %v", string(rb), len(rb)) assertEqual(t, true, len(rb) >= 45) } return nil @@ -816,17 +819,17 @@ func TestLzwCompress(t *testing.T) { // Not found scenario _, err := c.R().Get(ts.URL + "/lzw-test") assertNotNil(t, err) - assertEqual(t, ErrContentDecompressorNotFound, err) + assertEqual(t, ErrContentDecompresserNotFound, err) // Register LZW content decoder - c.AddContentDecompressor("compress", func(r io.ReadCloser) (io.ReadCloser, error) { + c.AddContentDecompresser("compress", func(r io.ReadCloser) (io.ReadCloser, error) { l := &lzwReader{ s: r, r: lzw.NewReader(r, lzw.LSB, 8), } return l, nil }) - c.SetContentDecompressorKeys([]string{"compress"}) + c.SetContentDecompresserKeys([]string{"compress"}) testcases := []struct{ url, want string }{ {ts.URL + "/lzw-test", "This is LZW response testing"}, diff --git a/curl_test.go b/curl_test.go index f380786e..0fdf6df2 100644 --- a/curl_test.go +++ b/curl_test.go @@ -198,7 +198,7 @@ func TestCurl_buildCurlCmd(t *testing.T) { req.SetHeader(k, v) } - err := createHTTPRequest(c, req) + err := createRawRequest(c, req) assertNil(t, err) if len(tt.cookies) > 0 { @@ -249,7 +249,20 @@ func TestCurlRequestGetBodyError(t *testing.T) { } else { t.Log("curlCmdUnexecuted: \n", curlCmdUnexecuted) } +} + +func TestCurlRequestMiddlewaresError(t *testing.T) { + errMsg := "middleware error" + c := dcnl().EnableDebug(). + SetRequestMiddlewares( + func(c *Client, r *Request) error { + return errors.New(errMsg) + }, + PrepareRequestMiddleware, + ) + curlCmdUnexecuted := c.R().EnableGenerateCurlOnDebug().GenerateCurlCommand() + assertEqual(t, "", curlCmdUnexecuted) } func TestCurlMiscTestCoverage(t *testing.T) { diff --git a/digest.go b/digest.go index 98ac6d73..3a58106c 100644 --- a/digest.go +++ b/digest.go @@ -59,15 +59,18 @@ func (dt *digestTransport) RoundTrip(req *http.Request) (*http.Response, error) // make a request to get the 401 that contains the challenge. res, err := dt.transport.RoundTrip(req1) - if err != nil || res.StatusCode != http.StatusUnauthorized { - return res, err + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusUnauthorized { + return res, nil } _, _ = ioCopy(io.Discard, res.Body) closeq(res.Body) chaHdrValue := strings.TrimSpace(res.Header.Get(hdrWwwAuthenticateKey)) if chaHdrValue == "" { - return res, ErrDigestBadChallenge + return nil, ErrDigestBadChallenge } cha, err := dt.parseChallenge(chaHdrValue) diff --git a/digest_test.go b/digest_test.go index 76f8410a..6c1d41fe 100644 --- a/digest_test.go +++ b/digest_test.go @@ -160,6 +160,24 @@ func TestClientDigestAuthWithBodyQopAuthIntIoCopyError(t *testing.T) { assertEqual(t, 0, resp.StatusCode()) } +func TestClientDigestAuthRoundTripError(t *testing.T) { + conf := *defaultDigestServerConf() + ts := createDigestServer(t, &conf) + defer ts.Close() + + c := dcnl().SetTransport(&CustomRoundTripper2{returnErr: true}) + c.SetDigestAuth(conf.username, conf.password) + + _, err := c.R(). + SetResult(&AuthSuccess{}). + SetHeader(hdrContentTypeKey, "application/json"). + SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). + Post(ts.URL + conf.uri) + + assertNotNil(t, err) + assertEqual(t, true, strings.Contains(err.Error(), "test req mock error")) +} + func TestClientDigestAuthWithBodyQopAuthIntGetBodyNil(t *testing.T) { conf := *defaultDigestServerConf() conf.qop = "auth-int" diff --git a/go.mod b/go.mod index 742b33a9..b575af94 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,4 @@ module resty.dev/v3 go 1.21 -require ( - golang.org/x/net v0.27.0 - golang.org/x/time v0.6.0 -) +require golang.org/x/net v0.27.0 diff --git a/go.sum b/go.sum index 66793eb1..75667142 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,2 @@ golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= -golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/middleware.go b/middleware.go index 49bf50ce..94be3b8b 100644 --- a/middleware.go +++ b/middleware.go @@ -11,6 +11,7 @@ import ( "io" "mime/multipart" "net/http" + "net/textproto" "net/url" "path/filepath" "reflect" @@ -26,9 +27,7 @@ const debugRequestLogKey = "__restyDebugRequestLog" // PrepareRequestMiddleware method is used to prepare HTTP requests from // user provides request values. Request preparation fails if any error occurs -func PrepareRequestMiddleware(c *Client, r *Request) error { - var err error - +func PrepareRequestMiddleware(c *Client, r *Request) (err error) { if err = parseRequestURL(c, r); err != nil { return err } @@ -40,9 +39,9 @@ func PrepareRequestMiddleware(c *Client, r *Request) error { return err } - if err = createHTTPRequest(c, r); err != nil { - return err - } + // at this point, possible error from `http.NewRequestWithContext` + // is URL-related, and those get caught up in the `parseRequestURL` + createRawRequest(c, r) // last one doesn't need if condition return addCredentials(c, r) @@ -62,76 +61,58 @@ func GenerateCurlRequestMiddleware(c *Client, r *Request) (err error) { } func parseRequestURL(c *Client, r *Request) error { - if l := len(c.PathParams()) + len(c.RawPathParams()) + len(r.PathParams) + len(r.RawPathParams); l > 0 { - params := make(map[string]string, l) - - // GitHub #103 Path Params - for p, v := range r.PathParams { - params[p] = url.PathEscape(v) - } + if len(c.PathParams())+len(r.PathParams) > 0 { + // GitHub #103 Path Params, #663 Raw Path Params for p, v := range c.PathParams() { - if _, ok := params[p]; !ok { - params[p] = url.PathEscape(v) + if _, ok := r.PathParams[p]; ok { + continue } + r.PathParams[p] = v } - // GitHub #663 Raw Path Params - for p, v := range r.RawPathParams { - if _, ok := params[p]; !ok { - params[p] = v + var prev int + buf := acquireBuffer() + defer releaseBuffer(buf) + // search for the next or first opened curly bracket + for curr := strings.Index(r.URL, "{"); curr == 0 || curr > prev; curr = prev + strings.Index(r.URL[prev:], "{") { + // write everything from the previous position up to the current + if curr > prev { + buf.WriteString(r.URL[prev:curr]) } - } - for p, v := range c.RawPathParams() { - if _, ok := params[p]; !ok { - params[p] = v + // search for the closed curly bracket from current position + next := curr + strings.Index(r.URL[curr:], "}") + // if not found, then write the remainder and exit + if next < curr { + buf.WriteString(r.URL[curr:]) + prev = len(r.URL) + break } - } - - if len(params) > 0 { - var prev int - buf := acquireBuffer() - defer releaseBuffer(buf) - // search for the next or first opened curly bracket - for curr := strings.Index(r.URL, "{"); curr == 0 || curr > prev; curr = prev + strings.Index(r.URL[prev:], "{") { - // write everything from the previous position up to the current - if curr > prev { - buf.WriteString(r.URL[prev:curr]) - } - // search for the closed curly bracket from current position - next := curr + strings.Index(r.URL[curr:], "}") - // if not found, then write the remainder and exit - if next < curr { - buf.WriteString(r.URL[curr:]) - prev = len(r.URL) - break - } - // special case for {}, without parameter's name - if next == curr+1 { - buf.WriteString("{}") - } else { - // check for the replacement - key := r.URL[curr+1 : next] - value, ok := params[key] - /// keep the original string if the replacement not found - if !ok { - value = r.URL[curr : next+1] - } - buf.WriteString(value) + // special case for {}, without parameter's name + if next == curr+1 { + buf.WriteString("{}") + } else { + // check for the replacement + key := r.URL[curr+1 : next] + value, ok := r.PathParams[key] + // keep the original string if the replacement not found + if !ok { + value = r.URL[curr : next+1] } + buf.WriteString(value) + } - // set the previous position after the closed curly bracket - prev = next + 1 - if prev >= len(r.URL) { - break - } + // set the previous position after the closed curly bracket + prev = next + 1 + if prev >= len(r.URL) { + break } - if buf.Len() > 0 { - // write remainder - if prev < len(r.URL) { - buf.WriteString(r.URL[prev:]) - } - r.URL = buf.String() + } + if buf.Len() > 0 { + // write remainder + if prev < len(r.URL) { + buf.WriteString(r.URL[prev:]) } + r.URL = buf.String() } } @@ -173,11 +154,9 @@ func parseRequestURL(c *Client, r *Request) error { // Adding Query Param if len(c.QueryParams())+len(r.QueryParams) > 0 { for k, v := range c.QueryParams() { - // skip query parameter if it was set in request if _, ok := r.QueryParams[k]; ok { continue } - r.QueryParams[k] = v[:] } @@ -185,12 +164,10 @@ func parseRequestURL(c *Client, r *Request) error { // Since not feasible in `SetQuery*` resty methods, because // standard package `url.Encode(...)` sorts the query params // alphabetically - if len(r.QueryParams) > 0 { - if isStringEmpty(reqURL.RawQuery) { - reqURL.RawQuery = r.QueryParams.Encode() - } else { - reqURL.RawQuery = reqURL.RawQuery + "&" + r.QueryParams.Encode() - } + if isStringEmpty(reqURL.RawQuery) { + reqURL.RawQuery = r.QueryParams.Encode() + } else { + reqURL.RawQuery = reqURL.RawQuery + "&" + r.QueryParams.Encode() } } @@ -221,7 +198,7 @@ func parseRequestHeader(c *Client, r *Request) error { } if !r.isHeaderExists(hdrAcceptEncodingKey) { - r.Header.Set(hdrAcceptEncodingKey, r.client.ContentDecompressorKeys()) + r.Header.Set(hdrAcceptEncodingKey, r.client.ContentDecompresserKeys()) } return nil @@ -250,19 +227,21 @@ func parseRequestBody(c *Client, r *Request) error { r.Body = nil // if the payload is not supported by HTTP verb, set explicit nil } - // by default resty won't set content length, you can if you want to :) + // by default resty won't set content length, but user can opt-in if r.setContentLength { - if r.bodyBuf == nil && r.Body == nil { - r.Header.Set(hdrContentLengthKey, "0") - } else if r.bodyBuf != nil { - r.Header.Set(hdrContentLengthKey, strconv.Itoa(r.bodyBuf.Len())) + cntLen := 0 + if r.bodyBuf != nil { + cntLen = r.bodyBuf.Len() + } else if b, ok := r.Body.(*bytes.Reader); ok { + cntLen = b.Len() } + r.Header.Set(hdrContentLengthKey, strconv.Itoa(cntLen)) } return nil } -func createHTTPRequest(c *Client, r *Request) (err error) { +func createRawRequest(c *Client, r *Request) (err error) { // init client trace if enabled r.initTraceIfEnabled() @@ -381,6 +360,10 @@ func handleMultipart(c *Client, r *Request) error { return nil } +var mpCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) { + return w.CreatePart(h) +} + func createMultipart(w *multipart.Writer, r *Request) error { if err := r.writeFormData(w); err != nil { return err @@ -408,18 +391,15 @@ func createMultipart(w *multipart.Writer, r *Request) error { mf.ContentType = http.DetectContentType(p[:size]) } - partWriter, err := w.CreatePart(mf.createHeader()) + partWriter, err := mpCreatePart(w, mf.createHeader()) if err != nil { return err } partWriter = mf.wrapProgressCallbackIfPresent(partWriter) + partWriter.Write(p[:size]) - if _, err = partWriter.Write(p[:size]); err != nil { - return err - } - _, err = ioCopy(partWriter, mf.Reader) - if err != nil { + if _, err = ioCopy(partWriter, mf.Reader); err != nil { return err } } @@ -460,13 +440,10 @@ func handleRequestBody(c *Client, r *Request) error { releaseBuffer(r.bodyBuf) r.bodyBuf = nil - // enable multiple reads if retry enabled - // and body type is *bytes.Buffer - if r.RetryCount > 0 { - if b, ok := r.Body.(*bytes.Buffer); ok { - v := b.Bytes() - r.Body = bytes.NewReader(v) - } + // enable multiple reads if body is *bytes.Buffer + if b, ok := r.Body.(*bytes.Buffer); ok { + v := b.Bytes() + r.Body = bytes.NewReader(v) } // do seek start for retry attempt if io.ReadSeeker diff --git a/middleware_test.go b/middleware_test.go index ac20ac32..d11ba215 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -114,12 +114,12 @@ func Test_parseRequestURL(t *testing.T) { r.SetPathParams(map[string]string{ "foo": "4/5", }).SetRawPathParams(map[string]string{ - "foo": "4/5", // ignored, because the pathParams takes precedence over the rawPathParams + "foo": "4/5", // it gets overwritten since same key name "bar": "6/7", }) r.URL = "https://example.com/{foo}/{bar}" }, - expectedURL: "https://example.com/4%2F5/6/7", + expectedURL: "https://example.com/4/5/6/7", }, { name: "empty path parameter in URL", @@ -351,48 +351,6 @@ func Test_parseRequestURL(t *testing.T) { } } -func Benchmark_parseRequestURL_PathParams(b *testing.B) { - c := New().SetPathParams(map[string]string{ - "foo": "1", - "bar": "2", - }).SetRawPathParams(map[string]string{ - "foo": "3", - "xyz": "4", - }) - r := c.R().SetPathParams(map[string]string{ - "foo": "5", - "qwe": "6", - }).SetRawPathParams(map[string]string{ - "foo": "7", - "asd": "8", - }) - b.ResetTimer() - for i := 0; i < b.N; i++ { - r.URL = "https://example.com/{foo}/{bar}/{xyz}/{qwe}/{asd}" - if err := parseRequestURL(c, r); err != nil { - b.Errorf("parseRequestURL() error = %v", err) - } - } -} - -func Benchmark_parseRequestURL_QueryParams(b *testing.B) { - c := New().SetQueryParams(map[string]string{ - "foo": "1", - "bar": "2", - }) - r := c.R().SetQueryParams(map[string]string{ - "foo": "5", - "qwe": "6", - }) - b.ResetTimer() - for i := 0; i < b.N; i++ { - r.URL = "https://example.com/" - if err := parseRequestURL(c, r); err != nil { - b.Errorf("parseRequestURL() error = %v", err) - } - } -} - func Test_parseRequestHeader(t *testing.T) { for _, tt := range []struct { name string @@ -479,7 +437,7 @@ func Test_parseRequestHeader(t *testing.T) { tt.init(c, r) // add common expected headers from client into expectedHeader - tt.expectedHeader.Set(hdrAcceptEncodingKey, c.ContentDecompressorKeys()) + tt.expectedHeader.Set(hdrAcceptEncodingKey, c.ContentDecompresserKeys()) if err := parseRequestHeader(c, r); err != nil { t.Errorf("parseRequestHeader() error = %v", err) @@ -491,25 +449,6 @@ func Test_parseRequestHeader(t *testing.T) { } } -func Benchmark_parseRequestHeader(b *testing.B) { - c := New() - r := c.R() - c.SetHeaders(map[string]string{ - "foo": "1", // ignored, because of the same header in the request - "bar": "2", - }) - r.SetHeaders(map[string]string{ - "foo": "3", - "xyz": "4", - }) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestHeader(c, r); err != nil { - b.Errorf("parseRequestHeader() error = %v", err) - } - } -} - func TestParseRequestBody(t *testing.T) { for _, tt := range []struct { name string @@ -920,149 +859,6 @@ func TestParseRequestBody(t *testing.T) { } } -func Benchmark_parseRequestBody_string(b *testing.B) { - c := New() - r := c.R() - r.SetBody("foo").SetContentLength(true) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_byte(b *testing.B) { - c := New() - r := c.R() - r.SetBody([]byte("foo")).SetContentLength(true) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_reader_with_SetContentLength(b *testing.B) { - c := New() - r := c.R() - r.SetBody(bytes.NewBufferString("foo")).SetContentLength(true) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_reader_without_SetContentLength(b *testing.B) { - c := New() - r := c.R() - r.SetBody(bytes.NewBufferString("foo")) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_struct(b *testing.B) { - type FooBar struct { - Foo string `json:"foo"` - Bar string `json:"bar"` - } - c := New() - r := c.R() - r.SetBody(FooBar{Foo: "1", Bar: "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_struct_xml(b *testing.B) { - type FooBar struct { - Foo string `xml:"foo"` - Bar string `xml:"bar"` - } - c := New() - r := c.R() - r.SetBody(FooBar{Foo: "1", Bar: "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, "text/xml") - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_map(b *testing.B) { - c := New() - r := c.R() - r.SetBody(map[string]string{ - "foo": "1", - "bar": "2", - }).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_slice(b *testing.B) { - c := New() - r := c.R() - r.SetBody([]string{"1", "2"}).SetContentLength(true).SetHeader(hdrContentTypeKey, jsonContentType) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_FormData(b *testing.B) { - c := New() - r := c.R() - c.SetFormData(map[string]string{"foo": "1", "bar": "2"}) - r.SetFormData(map[string]string{"foo": "3", "baz": "4"}).SetContentLength(true) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - -func Benchmark_parseRequestBody_MultiPart(b *testing.B) { - c := New() - r := c.R() - c.SetFormData(map[string]string{"foo": "1", "bar": "2"}) - r.SetFormData(map[string]string{"foo": "3", "baz": "4"}). - SetMultipartFormData(map[string]string{"foo": "5", "xyz": "6"}). - SetFileReader("qwe", "qwe.txt", strings.NewReader("7")). - SetMultipartFields( - &MultipartField{ - Name: "sdj", - ContentType: "text/plain", - Reader: strings.NewReader("8"), - }, - ). - SetContentLength(true) - b.ResetTimer() - for i := 0; i < b.N; i++ { - if err := parseRequestBody(c, r); err != nil { - b.Errorf("parseRequestBody() error = %v", err) - } - } -} - func TestMiddlewareSaveToFileErrorCases(t *testing.T) { c := dcnl() tempDir := t.TempDir() @@ -1137,6 +933,6 @@ func TestMiddlewareCoverage(t *testing.T) { req1 := c.R() req1.URL = "//invalid-url .local" - err1 := createHTTPRequest(c, req1) + err1 := createRawRequest(c, req1) assertEqual(t, true, strings.Contains(err1.Error(), "invalid character")) } diff --git a/multipart_test.go b/multipart_test.go index d42d596b..f84a213c 100644 --- a/multipart_test.go +++ b/multipart_test.go @@ -13,6 +13,7 @@ import ( "io/fs" "mime/multipart" "net/http" + "net/textproto" "net/url" "os" "path/filepath" @@ -213,6 +214,39 @@ func TestMultipartFormData(t *testing.T) { assertEqual(t, "Success", resp.String()) } +func TestMultipartFormDataFields(t *testing.T) { + ts := createFormPostServer(t) + defer ts.Close() + + fields := []*MultipartField{ + { + Name: "field1", + Values: []string{"field1value1", "field1value2"}, + }, + { + Name: "field1", + Values: []string{"field1value3", "field1value4"}, + }, + { + Name: "field3", + Values: []string{"field3value1", "field3value2"}, + }, + { + Name: "field4", + Values: []string{"field4value1", "field4value2"}, + }, + } + + resp, err := dcnldr(). + SetMultipartFields(fields...). + SetBasicAuth("myuser", "mypass"). + Post(ts.URL + "/profile") + + assertError(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) + assertEqual(t, "Success", resp.String()) +} + func TestMultipartField(t *testing.T) { ts := createFormPostServer(t) defer ts.Close() @@ -582,7 +616,28 @@ func TestMultipartRequest_createMultipart(t *testing.T) { mw := multipart.NewWriter(new(bytes.Buffer)) err := createMultipart(mw, req1) assertNotNil(t, err) - assertEqual(t, "test copy error", err.Error()) + assertEqual(t, errCopyMsg, err.Error()) + }) + + t.Run("multipart create part error", func(t *testing.T) { + errMsg := "test create part error" + mpCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) { + return nil, errors.New(errMsg) + } + t.Cleanup(func() { + mpCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) { + return w.CreatePart(h) + } + }) + + req1 := c.R(). + SetFile("file", filepath.Join(getTestDataPath(), "test-img.png")). + SetHeader("Content-Type", "image/png") + + mw := multipart.NewWriter(new(bytes.Buffer)) + err := createMultipart(mw, req1) + assertNotNil(t, err) + assertEqual(t, errMsg, err.Error()) }) } diff --git a/request.go b/request.go index b490da50..4e8cfe15 100644 --- a/request.go +++ b/request.go @@ -40,7 +40,6 @@ type Request struct { QueryParams url.Values FormData url.Values PathParams map[string]string - RawPathParams map[string]string Header http.Header Time time.Time Body any @@ -110,9 +109,9 @@ func (r *Request) GenerateCurlCommand() string { return r.resultCurlCmd } if r.RawRequest == nil { - // mock with r.Get("/") if err := r.client.executeRequestMiddlewares(r); err != nil { r.log.Errorf("%v", err) + return "" } } r.resultCurlCmd = buildCurlCmd(r) @@ -756,7 +755,7 @@ func (r *Request) SetResponseBodyUnlimitedReads(b bool) *Request { // // It overrides the path parameter set at the client instance level. func (r *Request) SetPathParam(param, value string) *Request { - r.PathParams[param] = value + r.PathParams[param] = url.PathEscape(value) return r } @@ -785,7 +784,7 @@ func (r *Request) SetPathParams(params map[string]string) *Request { } // SetRawPathParam method sets a single URL path key-value pair in the -// Resty current request instance. +// Resty current request instance without path escape. // // client.R().SetPathParam("userId", "sample@sample.com") // @@ -800,16 +799,16 @@ func (r *Request) SetPathParams(params map[string]string) *Request { // Composed URL - /v1/users/groups/developers/details // // It replaces the value of the key while composing the request URL. -// The value will be used as-is and has not been escaped. +// The value will be used as-is, no path escape applied. // // It overrides the raw path parameter set at the client instance level. func (r *Request) SetRawPathParam(param, value string) *Request { - r.RawPathParams[param] = value + r.PathParams[param] = value return r } // SetRawPathParams method sets multiple URL path key-value pairs at one go in the -// Resty current request instance. +// Resty current request instance without path escape. // // client.R().SetPathParams(map[string]string{ // "userId": "sample@sample.com", @@ -822,7 +821,7 @@ func (r *Request) SetRawPathParam(param, value string) *Request { // Composed URL - /v1/users/sample@sample.com/100002/groups/developers/details // // It replaces the value of the key while composing the request URL. -// The value will be used as-is and has not been escaped. +// The value will be used as-is, no path escape applied. // // It overrides the raw path parameter set at the client instance level. func (r *Request) SetRawPathParams(params map[string]string) *Request { @@ -1412,7 +1411,6 @@ func (r *Request) Clone(ctx context.Context) *Request { rr.FormData = cloneURLValues(r.FormData) rr.QueryParams = cloneURLValues(r.QueryParams) rr.PathParams = maps.Clone(r.PathParams) - rr.RawPathParams = maps.Clone(r.RawPathParams) // clone basic auth if r.credentials != nil { @@ -1450,7 +1448,7 @@ func (r *Request) Clone(ctx context.Context) *Request { // copy bodyBuf if r.bodyBuf != nil { rr.bodyBuf = acquireBuffer() - _, _ = io.Copy(rr.bodyBuf, r.bodyBuf) + rr.bodyBuf.Write(r.bodyBuf.Bytes()) } return rr diff --git a/request_test.go b/request_test.go index 3f0202d0..60e8e2c2 100644 --- a/request_test.go +++ b/request_test.go @@ -1655,8 +1655,8 @@ func TestRawPathParamURLInput(t *testing.T) { "path": "users/developers", }) - assertEqual(t, "sample@sample.com", c.RawPathParams()["userId"]) - assertEqual(t, "users/developers", c.RawPathParams()["path"]) + assertEqual(t, "sample@sample.com", c.PathParams()["userId"]) + assertEqual(t, "users/developers", c.PathParams()["path"]) resp, err := c.R().EnableDebug(). SetRawPathParams(map[string]string{ @@ -1934,7 +1934,6 @@ func TestRequestClone(t *testing.T) { // update value of non-interface type - change will only happen on clone clone.URL = "http://localhost.clone" clone.PathParams["name"] = "clone" - clone.RawPathParams["name"] = "clone" // update value of http header - change will only happen on clone clone.SetHeader("X-Header", "clone") // update value of interface type - change will only happen on clone @@ -1947,8 +1946,6 @@ func TestRequestClone(t *testing.T) { assertEqual(t, ts.URL, parent.URL) assertEqual(t, "clone", clone.PathParams["name"]) assertEqual(t, "parent", parent.PathParams["name"]) - assertEqual(t, "clone", clone.RawPathParams["name"]) - assertEqual(t, "parent", parent.RawPathParams["name"]) // assert http header assertEqual(t, "parent", parent.Header.Get("X-Header")) assertEqual(t, "clone", clone.Header.Get("X-Header")) @@ -2193,6 +2190,34 @@ func TestRequestSetResultAndSetOutputFile(t *testing.T) { assertEqual(t, `{ "id": "success", "message": "login successful" }`, string(fileContent)) } +func TestRequestBodyContentLength(t *testing.T) { + ts := createGenericServer(t) + defer ts.Close() + + c := dcnl().SetBaseURL(ts.URL) + + c.SetRequestMiddlewares( + PrepareRequestMiddleware, + func(c *Client, r *Request) error { + _, found := r.Header[hdrContentLengthKey] + assertEqual(t, true, found) + return nil + }, + ) + + buf := bytes.NewBuffer([]byte(`{"content":"json content sending to server"}`)) + res, err := c.R(). + SetHeader(hdrContentTypeKey, "application/json"). + SetContentLength(true). + SetBody(buf). + Put("/json") + + assertError(t, err) + assertEqual(t, http.StatusOK, res.StatusCode()) + assertEqual(t, `{"response":"json response"}`, res.String()) + assertEqual(t, int64(44), res.Request.RawRequest.ContentLength) +} + func TestRequestFuncs(t *testing.T) { ts := createGetServer(t) defer ts.Close() diff --git a/response.go b/response.go index 231ed4d2..691f4e91 100644 --- a/response.go +++ b/response.go @@ -247,13 +247,13 @@ func (r *Response) wrapCopyReadCloser() { } } -func (r *Response) wrapContentDecompressor() error { +func (r *Response) wrapContentDecompresser() error { ce := r.Header().Get(hdrContentEncodingKey) if isStringEmpty(ce) { return nil } - if decFunc, f := r.Request.client.ContentDecompressors()[ce]; f { + if decFunc, f := r.Request.client.ContentDecompressers()[ce]; f { dec, err := decFunc(r.Body) if err != nil { if err == io.EOF { @@ -268,7 +268,7 @@ func (r *Response) wrapContentDecompressor() error { r.Header().Del(hdrContentLengthKey) r.RawResponse.ContentLength = -1 } else { - return ErrContentDecompressorNotFound + return ErrContentDecompresserNotFound } return nil diff --git a/resty.go b/resty.go index 3a2ce07a..f624a260 100644 --- a/resty.go +++ b/resty.go @@ -169,15 +169,14 @@ func createClient(hc *http.Client) *Client { retryMaxWaitTime: defaultMaxWaitTime, isRetryDefaultConditions: true, pathParams: make(map[string]string), - rawPathParams: make(map[string]string), headerAuthorizationKey: hdrAuthorizationKey, jsonEscapeHTML: true, httpClient: hc, debugBodyLimit: math.MaxInt32, contentTypeEncoders: make(map[string]ContentTypeEncoder), contentTypeDecoders: make(map[string]ContentTypeDecoder), - contentDecompressorKeys: make([]string, 0), - contentDecompressors: make(map[string]ContentDecompressor), + contentDecompresserKeys: make([]string, 0), + contentDecompressers: make(map[string]ContentDecompresser), certWatcherStopChan: make(chan bool), } @@ -191,8 +190,8 @@ func createClient(hc *http.Client) *Client { c.AddContentTypeDecoder(xmlKey, decodeXML) // Order matter, giving priority to gzip - c.AddContentDecompressor("deflate", decompressDeflate) - c.AddContentDecompressor("gzip", decompressGzip) + c.AddContentDecompresser("deflate", decompressDeflate) + c.AddContentDecompresser("gzip", decompressGzip) // request middlewares c.SetRequestMiddlewares( diff --git a/resty_test.go b/resty_test.go index a1fd2248..66cd2c10 100644 --- a/resty_test.go +++ b/resty_test.go @@ -34,6 +34,7 @@ import ( var ( hdrLocationKey = http.CanonicalHeaderKey("Location") + hdrAcceptKey = http.CanonicalHeaderKey("Accept") ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ @@ -268,11 +269,6 @@ func handleUsersEndpoint(t *testing.T, w http.ResponseWriter, r *http.Request) { func createPostServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { - t.Logf("Method: %v", r.Method) - t.Logf("Path: %v", r.URL.Path) - t.Logf("RawQuery: %v", r.URL.RawQuery) - t.Logf("Content-Type: %v", r.Header.Get(hdrContentTypeKey)) - if r.Method == MethodPost { handleLoginEndpoint(t, w, r) @@ -356,18 +352,17 @@ func createPostServer(t *testing.T) *httptest.Server { func createFormPostServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { - t.Logf("Method: %v", r.Method) - t.Logf("Path: %v", r.URL.Path) - t.Logf("Content-Type: %v", r.Header.Get(hdrContentTypeKey)) - if r.Method == MethodPost { _ = r.ParseMultipartForm(10e6) if r.URL.Path == "/profile" { - t.Logf("FirstName: %v", r.FormValue("first_name")) - t.Logf("LastName: %v", r.FormValue("last_name")) - t.Logf("City: %v", r.FormValue("city")) - t.Logf("Zip Code: %v", r.FormValue("zip_code")) + if r.MultipartForm == nil { + values := r.Form + t.Log(values) + } else { + values := r.MultipartForm.Value + t.Log(values) + } _, _ = w.Write([]byte("Success")) return diff --git a/stream.go b/stream.go index b05b40ef..14c0f885 100644 --- a/stream.go +++ b/stream.go @@ -16,7 +16,7 @@ import ( ) var ( - ErrContentDecompressorNotFound = errors.New("resty: content decoder not found") + ErrContentDecompresserNotFound = errors.New("resty: content decoder not found") ) type ( @@ -26,13 +26,13 @@ type ( // ContentTypeDecoder type is for decoding the response body based on header Content-Type ContentTypeDecoder func(io.Reader, any) error - // ContentDecompressor type is for decompressing response body based on header Content-Encoding + // ContentDecompresser type is for decompressing response body based on header Content-Encoding // ([RFC 9110]) // // For example, gzip, deflate, etc. // // [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110 - ContentDecompressor func(io.ReadCloser) (io.ReadCloser, error) + ContentDecompresser func(io.ReadCloser) (io.ReadCloser, error) ) func encodeJSON(w io.Writer, v any) error { @@ -203,7 +203,7 @@ type nopReadCloser struct { func (r *nopReadCloser) Read(p []byte) (int, error) { n, err := r.r.Read(p) if err == io.EOF { - r.r.Seek(0, 0) + r.r.Seek(0, io.SeekStart) } return n, err }