Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support http.RoundTripper #91

Merged
merged 4 commits into from
Aug 18, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 57 additions & 21 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/x509"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -91,7 +92,6 @@ type Client struct {
JSONUnmarshal func(data []byte, v interface{}) error

httpClient *http.Client
transport *http.Transport
setContentLength bool
isHTTPMode bool
outputDirectory string
Expand Down Expand Up @@ -555,8 +555,12 @@ func (c *Client) Mode() string {
// Note: This method overwrites existing `TLSClientConfig`.
//
func (c *Client) SetTLSClientConfig(config *tls.Config) *Client {
c.transport.TLSClientConfig = config
c.httpClient.Transport = c.transport
transport, err := c.getTransport()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
transport.TLSClientConfig = config
return c
}

Expand All @@ -567,10 +571,14 @@ func (c *Client) SetTLSClientConfig(config *tls.Config) *Client {
// you can also set Proxy via environment variable. By default `Go` uses setting from `HTTP_PROXY`.
//
func (c *Client) SetProxy(proxyURL string) *Client {
transport, err := c.getTransport()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
if pURL, err := url.Parse(proxyURL); err == nil {
c.proxyURL = pURL
c.transport.Proxy = http.ProxyURL(c.proxyURL)
c.httpClient.Transport = c.transport
transport.Proxy = http.ProxyURL(c.proxyURL)
} else {
c.Log.Printf("ERROR [%v]", err)
c.RemoveProxy()
Expand All @@ -583,17 +591,24 @@ func (c *Client) SetProxy(proxyURL string) *Client {
// resty.RemoveProxy()
//
func (c *Client) RemoveProxy() *Client {
transport, err := c.getTransport()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
c.proxyURL = nil
c.transport.Proxy = nil
c.httpClient.Transport = c.transport

transport.Proxy = nil
return c
}

// SetCertificates method helps to set client certificates into resty conveniently.
//
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
config := c.getTLSConfig()
config, err := c.getTLSConfig()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
config.Certificates = append(config.Certificates, certs...)
return c
}
Expand All @@ -608,7 +623,11 @@ func (c *Client) SetRootCertificate(pemFilePath string) *Client {
return c
}

config := c.getTLSConfig()
config, err := c.getTLSConfig()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
if config.RootCAs == nil {
config.RootCAs = x509.NewCertPool()
}
Expand All @@ -634,9 +653,16 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client {
return c
}

// SetTransport method sets custom *http.Transport in the resty client. Its way to override default.
// SetTransport method sets custom `*http.Transport` or any `http.RoundTripper`
// compatible interface implementation in the resty client.
//
// Please Note:
//
// - If transport is not type of `*http.Transport` then you may not be able to
// take advantage of some of the `resty` client settings.
//
// - It overwrites the resty client transport instance and it's configurations.
//
// **Note:** It overwrites the default resty transport instance and its configurations.
// transport := &http.Transport{
// // somthing like Proxying to httptest.Server, etc...
// Proxy: func(req *http.Request) (*url.URL, error) {
Expand All @@ -646,12 +672,10 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client {
//
// resty.SetTransport(transport)
//
func (c *Client) SetTransport(transport *http.Transport) *Client {
func (c *Client) SetTransport(transport http.RoundTripper) *Client {
if transport != nil {
c.transport = transport
c.httpClient.Transport = c.transport
c.httpClient.Transport = transport
}

return c
}

Expand Down Expand Up @@ -769,12 +793,24 @@ func (c *Client) disableLogPrefix() {
}

// getting TLS client config if not exists then create one
func (c *Client) getTLSConfig() *tls.Config {
if c.transport.TLSClientConfig == nil {
c.transport.TLSClientConfig = &tls.Config{}
c.httpClient.Transport = c.transport
func (c *Client) getTLSConfig() (*tls.Config, error) {
transport, err := c.getTransport()
if err != nil {
return nil, err
}
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
return transport.TLSClientConfig, nil
}

// returns `*http.Transport` currently in use or error
// in case currently used `transport` is not an `*http.Transport`
func (c *Client) getTransport() (*http.Transport, error) {
if transport, ok := c.httpClient.Transport.(*http.Transport); ok {
return transport, nil
}
return c.transport.TLSClientConfig
return nil, errors.New("current transport is not an *http.Transport instance")
}

//
Expand Down
81 changes: 62 additions & 19 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,40 +136,49 @@ func TestClientProxy(t *testing.T) {
c.SetProxy("http://sampleproxy:8888")

resp, err := c.R().Get(ts.URL)
assertEqual(t, true, resp != nil)
assertEqual(t, true, err != nil)
assertNotNil(t, resp)
assertNotNil(t, err)

// Error
c.SetProxy("//not.a.user@%66%6f%6f.com:8888")

resp, err = c.R().
Get(ts.URL)
assertEqual(t, true, err == nil)
assertEqual(t, false, resp == nil)
assertNil(t, err)
assertNotNil(t, resp)
}

func TestSetCertificates(t *testing.T) {
func TestClientSetCertificates(t *testing.T) {
DefaultClient = dc()
SetCertificates(tls.Certificate{})

assertEqual(t, 1, len(DefaultClient.transport.TLSClientConfig.Certificates))
transport, err := DefaultClient.getTransport()

assertNil(t, err)
assertEqual(t, 1, len(transport.TLSClientConfig.Certificates))
}

func TestSetRootCertificate(t *testing.T) {
func TestClientSetRootCertificate(t *testing.T) {
DefaultClient = dc()
SetRootCertificate(getTestDataPath() + "/sample-root.pem")

assertEqual(t, true, DefaultClient.transport.TLSClientConfig.RootCAs != nil)
transport, err := DefaultClient.getTransport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.RootCAs)
}

func TestSetRootCertificateNotExists(t *testing.T) {
func TestClientSetRootCertificateNotExists(t *testing.T) {
DefaultClient = dc()
SetRootCertificate(getTestDataPath() + "/not-exists-sample-root.pem")

assertEqual(t, true, DefaultClient.transport.TLSClientConfig == nil)
transport, err := DefaultClient.getTransport()

assertNil(t, err)
assertNil(t, transport.TLSClientConfig)
}

func TestOnBeforeRequestModification(t *testing.T) {
func TestClientOnBeforeRequestModification(t *testing.T) {
tc := New()
tc.OnBeforeRequest(func(c *Client, r *Request) error {
r.SetAuthToken("This is test auth token")
Expand All @@ -184,42 +193,46 @@ func TestOnBeforeRequestModification(t *testing.T) {
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
assertEqual(t, "200 OK", resp.Status())
assertEqual(t, true, resp.Body() != nil)
assertNotNil(t, resp.Body())
assertEqual(t, "TestGet: text response", resp.String())

logResponse(t, resp)
}

func TestSetTransport(t *testing.T) {
func TestClientSetTransport(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()
DefaultClient = dc()

transport := &http.Transport{
// somthing like Proxying to httptest.Server, etc...
// something like Proxying to httptest.Server, etc...
Proxy: func(req *http.Request) (*url.URL, error) {
return url.Parse(ts.URL)
},
}
SetTransport(transport)

assertEqual(t, true, DefaultClient.transport != nil)
transportInUse, err := DefaultClient.getTransport()

assertNil(t, err)

assertEqual(t, true, transport == transportInUse)
}

func TestSetScheme(t *testing.T) {
func TestClientSetScheme(t *testing.T) {
DefaultClient = dc()

SetScheme("http")

assertEqual(t, true, DefaultClient.scheme == "http")
}

func TestSetCookieJar(t *testing.T) {
func TestClientSetCookieJar(t *testing.T) {
DefaultClient = dc()
backupJar := DefaultClient.httpClient.Jar

SetCookieJar(nil)
assertEqual(t, true, DefaultClient.httpClient.Jar == nil)
assertNil(t, DefaultClient.httpClient.Jar)

SetCookieJar(backupJar)
assertEqual(t, true, DefaultClient.httpClient.Jar == backupJar)
Expand Down Expand Up @@ -304,7 +317,10 @@ func TestClientOptions(t *testing.T) {
}

SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
assertEqual(t, true, DefaultClient.transport.TLSClientConfig.InsecureSkipVerify)
transport, transportErr := DefaultClient.getTransport()

assertNil(t, transportErr)
assertEqual(t, true, transport.TLSClientConfig.InsecureSkipVerify)

OnBeforeRequest(func(c *Client, r *Request) error {
c.Log.Println("I'm in Request middleware")
Expand Down Expand Up @@ -363,3 +379,30 @@ func TestClientAllowsGetMethodPayload(t *testing.T) {
assertEqual(t, http.StatusOK, resp.StatusCode())
assertEqual(t, payload, resp.String())
}

func TestClientRoundTripper(t *testing.T) {
c := New()

rt := &CustomRoundTripper{}
c.SetTransport(rt)

ct, err := c.getTransport()
assertNotNil(t, err)
assertNil(t, ct)
assertEqual(t, "current transport is not an *http.Transport instance", err.Error())

c.SetTLSClientConfig(&tls.Config{})
c.SetProxy("http://localhost:9090")
c.RemoveProxy()
c.SetCertificates(tls.Certificate{})
c.SetRootCertificate(getTestDataPath() + "/sample-root.pem")
}

// CustomRoundTripper just for test
type CustomRoundTripper struct {
}

// RoundTrip just for test
func (rt *CustomRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}
9 changes: 5 additions & 4 deletions default.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ func New() *Client {
JSONMarshal: json.Marshal,
JSONUnmarshal: json.Unmarshal,
httpClient: &http.Client{Jar: cookieJar},
transport: &http.Transport{},
}

c.httpClient.Transport = c.transport
// Default transport
c.SetTransport(&http.Transport{})

// Default redirect policy
c.SetRedirectPolicy(NoRedirectPolicy())
Expand Down Expand Up @@ -254,9 +254,10 @@ func SetOutputDirectory(dirPath string) *Client {
return DefaultClient.SetOutputDirectory(dirPath)
}

// SetTransport method sets custom *http.Transport in the resty client.
// SetTransport method sets custom `*http.Transport` or any `http.RoundTripper`
// compatible interface implementation in the resty client.
// See `Client.SetTransport` for more information.
func SetTransport(transport *http.Transport) *Client {
func SetTransport(transport http.RoundTripper) *Client {
return DefaultClient.SetTransport(transport)
}

Expand Down
Loading