diff --git a/consumer.go b/consumer.go index f2899fc8..61acde51 100644 --- a/consumer.go +++ b/consumer.go @@ -340,7 +340,8 @@ func (r *Consumer) ConnectToNSQLookupd(addr string) error { return errors.New("no handlers") } - if err := validatedLookupAddr(addr); err != nil { + parsedAddr, err := buildLookupAddr(addr, r.topic) + if err != nil { return err } @@ -348,12 +349,12 @@ func (r *Consumer) ConnectToNSQLookupd(addr string) error { r.mtx.Lock() for _, x := range r.lookupdHTTPAddrs { - if x == addr { + if x == parsedAddr { r.mtx.Unlock() return nil } } - r.lookupdHTTPAddrs = append(r.lookupdHTTPAddrs, addr) + r.lookupdHTTPAddrs = append(r.lookupdHTTPAddrs, parsedAddr) numLookupd := len(r.lookupdHTTPAddrs) r.mtx.Unlock() @@ -383,20 +384,6 @@ func (r *Consumer) ConnectToNSQLookupds(addresses []string) error { return nil } -func validatedLookupAddr(addr string) error { - if strings.Contains(addr, "/") { - _, err := url.Parse(addr) - if err != nil { - return err - } - return nil - } - if !strings.Contains(addr, ":") { - return errors.New("missing port") - } - return nil -} - // poll all known lookup servers every LookupdPollInterval func (r *Consumer) lookupdLoop() { // add some jitter so that multiple consumers discovering the same topic, @@ -446,23 +433,7 @@ func (r *Consumer) nextLookupdEndpoint() string { r.mtx.RUnlock() r.lookupdQueryIndex = (r.lookupdQueryIndex + 1) % num - urlString := addr - if !strings.Contains(urlString, "://") { - urlString = "http://" + addr - } - - u, err := url.Parse(urlString) - if err != nil { - panic(err) - } - if u.Path == "/" || u.Path == "" { - u.Path = "/lookup" - } - - v, err := url.ParseQuery(u.RawQuery) - v.Add("topic", r.topic) - u.RawQuery = v.Encode() - return u.String() + return addr } type lookupResp struct { @@ -659,10 +630,15 @@ func (r *Consumer) DisconnectFromNSQD(addr string) error { // DisconnectFromNSQLookupd removes the specified `nsqlookupd` address // from the list used for periodic discovery. func (r *Consumer) DisconnectFromNSQLookupd(addr string) error { + parsedAddr, err := buildLookupAddr(addr, r.topic) + if err != nil { + return err + } + r.mtx.Lock() defer r.mtx.Unlock() - idx := indexOf(addr, r.lookupdHTTPAddrs) + idx := indexOf(parsedAddr, r.lookupdHTTPAddrs) if idx == -1 { return ErrNotConnected } @@ -1204,3 +1180,28 @@ func (r *Consumer) log(lvl LogLevel, line string, args ...interface{}) { lvl, r.id, r.topic, r.channel, fmt.Sprintf(line, args...))) } + +func buildLookupAddr(addr, topic string) (string, error) { + urlString := addr + if !strings.Contains(urlString, "://") { + urlString = "http://" + addr + } + + u, err := url.Parse(urlString) + if err != nil { + return "", err + } + + if u.Port() == "" { + return "", errors.New("missing port") + } + + if u.Path == "/" || u.Path == "" { + u.Path = "/lookup" + } + + v, err := url.ParseQuery(u.RawQuery) + v.Add("topic", topic) + u.RawQuery = v.Encode() + return u.String(), nil +}