Skip to content

Commit

Permalink
Added support http.RoundTripper (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeevatkm authored Aug 18, 2017
1 parent 07d07bf commit 8d8edd7
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 77 deletions.
78 changes: 57 additions & 21 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/x509"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -91,7 +92,6 @@ type Client struct {
JSONUnmarshal func(data []byte, v interface{}) error

httpClient *http.Client
transport *http.Transport
setContentLength bool
isHTTPMode bool
outputDirectory string
Expand Down Expand Up @@ -555,8 +555,12 @@ func (c *Client) Mode() string {
// Note: This method overwrites existing `TLSClientConfig`.
//
func (c *Client) SetTLSClientConfig(config *tls.Config) *Client {
c.transport.TLSClientConfig = config
c.httpClient.Transport = c.transport
transport, err := c.getTransport()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
transport.TLSClientConfig = config
return c
}

Expand All @@ -567,10 +571,14 @@ func (c *Client) SetTLSClientConfig(config *tls.Config) *Client {
// you can also set Proxy via environment variable. By default `Go` uses setting from `HTTP_PROXY`.
//
func (c *Client) SetProxy(proxyURL string) *Client {
transport, err := c.getTransport()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
if pURL, err := url.Parse(proxyURL); err == nil {
c.proxyURL = pURL
c.transport.Proxy = http.ProxyURL(c.proxyURL)
c.httpClient.Transport = c.transport
transport.Proxy = http.ProxyURL(c.proxyURL)
} else {
c.Log.Printf("ERROR [%v]", err)
c.RemoveProxy()
Expand All @@ -583,17 +591,24 @@ func (c *Client) SetProxy(proxyURL string) *Client {
// resty.RemoveProxy()
//
func (c *Client) RemoveProxy() *Client {
transport, err := c.getTransport()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
c.proxyURL = nil
c.transport.Proxy = nil
c.httpClient.Transport = c.transport

transport.Proxy = nil
return c
}

// SetCertificates method helps to set client certificates into resty conveniently.
//
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
config := c.getTLSConfig()
config, err := c.getTLSConfig()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
config.Certificates = append(config.Certificates, certs...)
return c
}
Expand All @@ -608,7 +623,11 @@ func (c *Client) SetRootCertificate(pemFilePath string) *Client {
return c
}

config := c.getTLSConfig()
config, err := c.getTLSConfig()
if err != nil {
c.Log.Printf("ERROR [%v]", err)
return c
}
if config.RootCAs == nil {
config.RootCAs = x509.NewCertPool()
}
Expand All @@ -634,9 +653,16 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client {
return c
}

// SetTransport method sets custom *http.Transport in the resty client. Its way to override default.
// SetTransport method sets custom `*http.Transport` or any `http.RoundTripper`
// compatible interface implementation in the resty client.
//
// Please Note:
//
// - If transport is not type of `*http.Transport` then you may not be able to
// take advantage of some of the `resty` client settings.
//
// - It overwrites the resty client transport instance and it's configurations.
//
// **Note:** It overwrites the default resty transport instance and its configurations.
// transport := &http.Transport{
// // somthing like Proxying to httptest.Server, etc...
// Proxy: func(req *http.Request) (*url.URL, error) {
Expand All @@ -646,12 +672,10 @@ func (c *Client) SetOutputDirectory(dirPath string) *Client {
//
// resty.SetTransport(transport)
//
func (c *Client) SetTransport(transport *http.Transport) *Client {
func (c *Client) SetTransport(transport http.RoundTripper) *Client {
if transport != nil {
c.transport = transport
c.httpClient.Transport = c.transport
c.httpClient.Transport = transport
}

return c
}

Expand Down Expand Up @@ -769,12 +793,24 @@ func (c *Client) disableLogPrefix() {
}

// getting TLS client config if not exists then create one
func (c *Client) getTLSConfig() *tls.Config {
if c.transport.TLSClientConfig == nil {
c.transport.TLSClientConfig = &tls.Config{}
c.httpClient.Transport = c.transport
func (c *Client) getTLSConfig() (*tls.Config, error) {
transport, err := c.getTransport()
if err != nil {
return nil, err
}
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
return transport.TLSClientConfig, nil
}

// returns `*http.Transport` currently in use or error
// in case currently used `transport` is not an `*http.Transport`
func (c *Client) getTransport() (*http.Transport, error) {
if transport, ok := c.httpClient.Transport.(*http.Transport); ok {
return transport, nil
}
return c.transport.TLSClientConfig
return nil, errors.New("current transport is not an *http.Transport instance")
}

//
Expand Down
81 changes: 62 additions & 19 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,40 +136,49 @@ func TestClientProxy(t *testing.T) {
c.SetProxy("http://sampleproxy:8888")

resp, err := c.R().Get(ts.URL)
assertEqual(t, true, resp != nil)
assertEqual(t, true, err != nil)
assertNotNil(t, resp)
assertNotNil(t, err)

// Error
c.SetProxy("//not.a.user@%66%6f%6f.com:8888")

resp, err = c.R().
Get(ts.URL)
assertEqual(t, true, err == nil)
assertEqual(t, false, resp == nil)
assertNil(t, err)
assertNotNil(t, resp)
}

func TestSetCertificates(t *testing.T) {
func TestClientSetCertificates(t *testing.T) {
DefaultClient = dc()
SetCertificates(tls.Certificate{})

assertEqual(t, 1, len(DefaultClient.transport.TLSClientConfig.Certificates))
transport, err := DefaultClient.getTransport()

assertNil(t, err)
assertEqual(t, 1, len(transport.TLSClientConfig.Certificates))
}

func TestSetRootCertificate(t *testing.T) {
func TestClientSetRootCertificate(t *testing.T) {
DefaultClient = dc()
SetRootCertificate(getTestDataPath() + "/sample-root.pem")

assertEqual(t, true, DefaultClient.transport.TLSClientConfig.RootCAs != nil)
transport, err := DefaultClient.getTransport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.RootCAs)
}

func TestSetRootCertificateNotExists(t *testing.T) {
func TestClientSetRootCertificateNotExists(t *testing.T) {
DefaultClient = dc()
SetRootCertificate(getTestDataPath() + "/not-exists-sample-root.pem")

assertEqual(t, true, DefaultClient.transport.TLSClientConfig == nil)
transport, err := DefaultClient.getTransport()

assertNil(t, err)
assertNil(t, transport.TLSClientConfig)
}

func TestOnBeforeRequestModification(t *testing.T) {
func TestClientOnBeforeRequestModification(t *testing.T) {
tc := New()
tc.OnBeforeRequest(func(c *Client, r *Request) error {
r.SetAuthToken("This is test auth token")
Expand All @@ -184,42 +193,46 @@ func TestOnBeforeRequestModification(t *testing.T) {
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
assertEqual(t, "200 OK", resp.Status())
assertEqual(t, true, resp.Body() != nil)
assertNotNil(t, resp.Body())
assertEqual(t, "TestGet: text response", resp.String())

logResponse(t, resp)
}

func TestSetTransport(t *testing.T) {
func TestClientSetTransport(t *testing.T) {
ts := createGetServer(t)
defer ts.Close()
DefaultClient = dc()

transport := &http.Transport{
// somthing like Proxying to httptest.Server, etc...
// something like Proxying to httptest.Server, etc...
Proxy: func(req *http.Request) (*url.URL, error) {
return url.Parse(ts.URL)
},
}
SetTransport(transport)

assertEqual(t, true, DefaultClient.transport != nil)
transportInUse, err := DefaultClient.getTransport()

assertNil(t, err)

assertEqual(t, true, transport == transportInUse)
}

func TestSetScheme(t *testing.T) {
func TestClientSetScheme(t *testing.T) {
DefaultClient = dc()

SetScheme("http")

assertEqual(t, true, DefaultClient.scheme == "http")
}

func TestSetCookieJar(t *testing.T) {
func TestClientSetCookieJar(t *testing.T) {
DefaultClient = dc()
backupJar := DefaultClient.httpClient.Jar

SetCookieJar(nil)
assertEqual(t, true, DefaultClient.httpClient.Jar == nil)
assertNil(t, DefaultClient.httpClient.Jar)

SetCookieJar(backupJar)
assertEqual(t, true, DefaultClient.httpClient.Jar == backupJar)
Expand Down Expand Up @@ -304,7 +317,10 @@ func TestClientOptions(t *testing.T) {
}

SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
assertEqual(t, true, DefaultClient.transport.TLSClientConfig.InsecureSkipVerify)
transport, transportErr := DefaultClient.getTransport()

assertNil(t, transportErr)
assertEqual(t, true, transport.TLSClientConfig.InsecureSkipVerify)

OnBeforeRequest(func(c *Client, r *Request) error {
c.Log.Println("I'm in Request middleware")
Expand Down Expand Up @@ -363,3 +379,30 @@ func TestClientAllowsGetMethodPayload(t *testing.T) {
assertEqual(t, http.StatusOK, resp.StatusCode())
assertEqual(t, payload, resp.String())
}

func TestClientRoundTripper(t *testing.T) {
c := New()

rt := &CustomRoundTripper{}
c.SetTransport(rt)

ct, err := c.getTransport()
assertNotNil(t, err)
assertNil(t, ct)
assertEqual(t, "current transport is not an *http.Transport instance", err.Error())

c.SetTLSClientConfig(&tls.Config{})
c.SetProxy("http://localhost:9090")
c.RemoveProxy()
c.SetCertificates(tls.Certificate{})
c.SetRootCertificate(getTestDataPath() + "/sample-root.pem")
}

// CustomRoundTripper just for test
type CustomRoundTripper struct {
}

// RoundTrip just for test
func (rt *CustomRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}
9 changes: 5 additions & 4 deletions default.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ func New() *Client {
JSONMarshal: json.Marshal,
JSONUnmarshal: json.Unmarshal,
httpClient: &http.Client{Jar: cookieJar},
transport: &http.Transport{},
}

c.httpClient.Transport = c.transport
// Default transport
c.SetTransport(&http.Transport{})

// Default redirect policy
c.SetRedirectPolicy(NoRedirectPolicy())
Expand Down Expand Up @@ -254,9 +254,10 @@ func SetOutputDirectory(dirPath string) *Client {
return DefaultClient.SetOutputDirectory(dirPath)
}

// SetTransport method sets custom *http.Transport in the resty client.
// SetTransport method sets custom `*http.Transport` or any `http.RoundTripper`
// compatible interface implementation in the resty client.
// See `Client.SetTransport` for more information.
func SetTransport(transport *http.Transport) *Client {
func SetTransport(transport http.RoundTripper) *Client {
return DefaultClient.SetTransport(transport)
}

Expand Down
Loading

0 comments on commit 8d8edd7

Please sign in to comment.