Skip to content

Commit

Permalink
Merge pull request #51 from DennyLoko/query-srv
Browse files Browse the repository at this point in the history
Request based on the SRV record of the service
  • Loading branch information
jeevatkm authored Jan 17, 2017
2 parents 2e0c231 + 5213400 commit bbe60b3
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 3 deletions.
2 changes: 1 addition & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func parseRequestBody(c *Client, r *Request) (err error) {

CL:
// by default resty won't set content length, you can if you want to :)
if c.setContentLength || r.setContentLength {
if (c.setContentLength || r.setContentLength) && r.bodyBuf != nil {
r.Header.Set(hdrContentLengthKey, fmt.Sprintf("%d", r.bodyBuf.Len()))
}

Expand Down
45 changes: 43 additions & 2 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@ import (
"encoding/xml"
"fmt"
"io"
"net"
"net/url"
"reflect"
"strings"
)

// SRVRecord holds the data to query the SRV record for the following service
type SRVRecord struct {
Service string
Domain string
}

// SetHeader method is to set a single header field and its value in the current request.
// Example: To set `Content-Type` and `Accept` as `application/json`.
// resty.R().
Expand Down Expand Up @@ -345,6 +352,16 @@ func (r *Request) SetProxy(proxyURL string) *Request {
return r
}

// SetSRV method sets the details to query the service SRV record and execute the
// request.
// resty.R().
// SetSRV(SRVRecord{"web", "testservice.com"}).
// Get("/get")
func (r *Request) SetSRV(srv *SRVRecord) *Request {
r.SRV = srv
return r
}

//
// HTTP verb method starts here
//
Expand Down Expand Up @@ -389,22 +406,34 @@ func (r *Request) Patch(url string) (*Response, error) {
// resp, err := resty.R().Execute(resty.GET, "http://httpbin.org/get")
//
func (r *Request) Execute(method, url string) (*Response, error) {
var addrs []*net.SRV
var err error

if r.isMultiPart && !(method == MethodPost || method == MethodPut) {
return nil, fmt.Errorf("Multipart content is not allowed in HTTP verb [%v]", method)
}

if r.SRV != nil {
_, addrs, err = net.LookupSRV(r.SRV.Service, "tcp", r.SRV.Domain)
if err != nil {
return nil, err
}
}

r.Method = method
r.URL = url
r.URL = r.selectAddr(addrs, url, 0)

if r.client.RetryCount == 0 {
return r.client.execute(r)
}

var resp *Response
var err error
attempt := 0
_ = Backoff(func() (*Response, error) {
attempt++

r.URL = r.selectAddr(addrs, url, attempt)

resp, err = r.client.execute(r)
if err != nil {
r.client.Log.Printf("ERROR [%v] Attempt [%v]", err, attempt)
Expand Down Expand Up @@ -465,3 +494,15 @@ func (r *Request) fmtBodyString() (body string) {

return
}

func (r *Request) selectAddr(addrs []*net.SRV, path string, attempt int) string {
if addrs == nil {
return path
}

idx := attempt % len(addrs)
domain := strings.TrimRight(addrs[idx].Target, ".")
path = strings.TrimLeft(path, "/")

return fmt.Sprintf("%s://%s:%d/%s", r.client.scheme, domain, addrs[idx].Port, path)
}
1 change: 1 addition & 0 deletions request16.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type Request struct {
Error interface{}
Time time.Time
RawRequest *http.Request
SRV *SRVRecord

client *Client
bodyBuf *bytes.Buffer
Expand Down
1 change: 1 addition & 0 deletions request17.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Request struct {
Error interface{}
Time time.Time
RawRequest *http.Request
SRV *SRVRecord

client *Client
bodyBuf *bytes.Buffer
Expand Down
25 changes: 25 additions & 0 deletions resty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,31 @@ func TestClientOptions(t *testing.T) {
SetLogger(ioutil.Discard)
}

func TestSRV(t *testing.T) {
c := dc().
SetRedirectPolicy(FlexibleRedirectPolicy(20)).
SetScheme("http")

r := c.R().
SetSRV(&SRVRecord{"xmpp-server", "google.com"})

assertEqual(t, "xmpp-server", r.SRV.Service)
assertEqual(t, "google.com", r.SRV.Domain)

resp, err := r.Get("/")
assertError(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())
}

func TestSRVInvalidService(t *testing.T) {
_, err := R().
SetSRV(&SRVRecord{"nonexistantservice", "sampledomain"}).
Get("/")

assertEqual(t, true, (err != nil))
assertEqual(t, true, strings.Contains(err.Error(), "no such host"))
}

func getTestDataPath() string {
pwd, _ := os.Getwd()
return pwd + "/test-data"
Expand Down

0 comments on commit bbe60b3

Please sign in to comment.