diff --git a/api_request.go b/api_request.go index d4e389c9..d3e97596 100644 --- a/api_request.go +++ b/api_request.go @@ -3,7 +3,6 @@ package nsq import ( "encoding/json" "fmt" - "io" "io/ioutil" "net" "net/http" @@ -46,12 +45,15 @@ type wrappedResp struct { } // stores the result in the value pointed to by ret(must be a pointer) -func apiRequestNegotiateV1(method string, endpoint string, body io.Reader, ret interface{}) error { +func apiRequestNegotiateV1(method string, endpoint string, headers http.Header, ret interface{}) error { httpclient := &http.Client{Transport: newDeadlineTransport(2 * time.Second)} - req, err := http.NewRequest(method, endpoint, body) + req, err := http.NewRequest(method, endpoint, nil) if err != nil { return err } + for k, v := range headers { + req.Header[k] = v + } req.Header.Add("Accept", "application/vnd.nsq; version=1.0") diff --git a/config.go b/config.go index 05a81575..d799c5ff 100644 --- a/config.go +++ b/config.go @@ -176,8 +176,10 @@ type Config struct { // The server-side message timeout for messages delivered to this client MsgTimeout time.Duration `opt:"msg_timeout" min:"0"` - // secret for nsqd authentication (requires nsqd 0.2.29+) + // Secret for nsqd authentication (requires nsqd 0.2.29+) AuthSecret string `opt:"auth_secret"` + // Use AuthSecret as 'Authorization: Bearer {AuthSecret}' on lookupd queries + LookupdAuthorization bool `opt:"skip_lookupd_authorization" default:"true"` } // NewConfig returns a new default nsq configuration. diff --git a/consumer.go b/consumer.go index 04cb1f62..f2899fc8 100644 --- a/consumer.go +++ b/consumer.go @@ -8,6 +8,7 @@ import ( "math" "math/rand" "net" + "net/http" "net/url" "os" "strconv" @@ -492,7 +493,11 @@ retry: r.log(LogLevelInfo, "querying nsqlookupd %s", endpoint) var data lookupResp - err := apiRequestNegotiateV1("GET", endpoint, nil, &data) + headers := make(http.Header) + if r.config.AuthSecret != "" && r.config.LookupdAuthorization { + headers.Set("Authorization", fmt.Sprintf("Bearer %s", r.config.AuthSecret)) + } + err := apiRequestNegotiateV1("GET", endpoint, headers, &data) if err != nil { r.log(LogLevelError, "error querying nsqlookupd (%s) - %s", endpoint, err) retries++ diff --git a/consumer_test.go b/consumer_test.go index 5eb5a961..dcb34df0 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -10,6 +10,7 @@ import ( "log" "net" "net/http" + "net/http/httptest" "strconv" "strings" "testing" @@ -125,6 +126,36 @@ func TestConsumerTLSClientCert(t *testing.T) { }) } +func TestConsumerLookupdAuthorization(t *testing.T) { + // confirm that LookupAuthorization = true sets Authorization header on lookudp call + config := NewConfig() + config.AuthSecret = "AuthSecret" + topicName := "auth" + strconv.Itoa(int(time.Now().Unix())) + q, _ := NewConsumer(topicName, "ch", config) + q.SetLogger(newTestLogger(t), LogLevelDebug) + + var req bool + lookupd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req = true + if h := r.Header.Get("Authorization"); h != "Bearer AuthSecret" { + t.Errorf("got Auth header %q", h) + } + w.WriteHeader(404) + })) + defer lookupd.Close() + + h := &MyTestHandler{ + t: t, + q: q, + } + q.AddHandler(h) + + q.ConnectToNSQLookupd(lookupd.URL) + if req == false { + t.Errorf("lookupd call not completed") + } +} + func TestConsumerTLSClientCertViaSet(t *testing.T) { consumerTest(t, func(c *Config) { c.Set("tls_v1", true)