Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

consumer: fix panic in nextLookupdEndpoint #321

Merged
merged 1 commit into from
May 31, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 36 additions & 35 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,20 +340,21 @@ func (r *Consumer) ConnectToNSQLookupd(addr string) error {
return errors.New("no handlers")
}

if err := validatedLookupAddr(addr); err != nil {
parsedAddr, err := buildLookupAddr(addr, r.topic)
if err != nil {
return err
}

atomic.StoreInt32(&r.connectedFlag, 1)

r.mtx.Lock()
for _, x := range r.lookupdHTTPAddrs {
if x == addr {
if x == parsedAddr {
r.mtx.Unlock()
return nil
}
}
r.lookupdHTTPAddrs = append(r.lookupdHTTPAddrs, addr)
r.lookupdHTTPAddrs = append(r.lookupdHTTPAddrs, parsedAddr)
numLookupd := len(r.lookupdHTTPAddrs)
r.mtx.Unlock()

Expand Down Expand Up @@ -383,20 +384,6 @@ func (r *Consumer) ConnectToNSQLookupds(addresses []string) error {
return nil
}

func validatedLookupAddr(addr string) error {
if strings.Contains(addr, "/") {
_, err := url.Parse(addr)
if err != nil {
return err
}
return nil
}
if !strings.Contains(addr, ":") {
return errors.New("missing port")
}
return nil
}

// poll all known lookup servers every LookupdPollInterval
func (r *Consumer) lookupdLoop() {
// add some jitter so that multiple consumers discovering the same topic,
Expand Down Expand Up @@ -446,23 +433,7 @@ func (r *Consumer) nextLookupdEndpoint() string {
r.mtx.RUnlock()
r.lookupdQueryIndex = (r.lookupdQueryIndex + 1) % num

urlString := addr
if !strings.Contains(urlString, "://") {
urlString = "http://" + addr
}

u, err := url.Parse(urlString)
if err != nil {
panic(err)
}
if u.Path == "/" || u.Path == "" {
u.Path = "/lookup"
}

v, err := url.ParseQuery(u.RawQuery)
v.Add("topic", r.topic)
u.RawQuery = v.Encode()
return u.String()
return addr
}

type lookupResp struct {
Expand Down Expand Up @@ -659,10 +630,15 @@ func (r *Consumer) DisconnectFromNSQD(addr string) error {
// DisconnectFromNSQLookupd removes the specified `nsqlookupd` address
// from the list used for periodic discovery.
func (r *Consumer) DisconnectFromNSQLookupd(addr string) error {
parsedAddr, err := buildLookupAddr(addr, r.topic)
if err != nil {
return err
}

r.mtx.Lock()
defer r.mtx.Unlock()

idx := indexOf(addr, r.lookupdHTTPAddrs)
idx := indexOf(parsedAddr, r.lookupdHTTPAddrs)
if idx == -1 {
return ErrNotConnected
}
Expand Down Expand Up @@ -1204,3 +1180,28 @@ func (r *Consumer) log(lvl LogLevel, line string, args ...interface{}) {
lvl, r.id, r.topic, r.channel,
fmt.Sprintf(line, args...)))
}

func buildLookupAddr(addr, topic string) (string, error) {
urlString := addr
if !strings.Contains(urlString, "://") {
urlString = "http://" + addr
}

u, err := url.Parse(urlString)
if err != nil {
return "", err
}

if u.Port() == "" {
return "", errors.New("missing port")
}

if u.Path == "/" || u.Path == "" {
u.Path = "/lookup"
}

v, err := url.ParseQuery(u.RawQuery)
v.Add("topic", topic)
u.RawQuery = v.Encode()
return u.String(), nil
}