diff --git a/consumer.go b/consumer.go index 7438334b..0173a8ec 100644 --- a/consumer.go +++ b/consumer.go @@ -100,9 +100,11 @@ type Consumer struct { rdyRetryMtx sync.RWMutex rdyRetryTimers map[string]*time.Timer - pendingConnections map[string]bool + pendingConnections map[string]*Conn connections map[string]*Conn + nsqdTCPAddrs []string + // used at connection close to force a possible reconnect lookupdRecheckChan chan int lookupdHTTPAddrs []string @@ -152,7 +154,7 @@ func NewConsumer(topic string, channel string, config *Config) (*Consumer, error incomingMessages: make(chan *Message), rdyRetryTimers: make(map[string]*time.Timer), - pendingConnections: make(map[string]bool), + pendingConnections: make(map[string]*Conn), connections: make(map[string]*Conn), lookupdRecheckChan: make(chan int, 1), @@ -455,21 +457,25 @@ func (r *Consumer) ConnectToNSQD(addr string) error { atomic.StoreInt32(&r.connectedFlag, 1) + conn := NewConn(addr, &r.config, &consumerConnDelegate{r}) + conn.SetLogger(r.logger, r.logLvl, + fmt.Sprintf("%3d [%s/%s] (%%s)", r.id, r.topic, r.channel)) + + r.mtx.Lock() _, pendingOk := r.pendingConnections[addr] - r.mtx.RLock() _, ok := r.connections[addr] - r.mtx.RUnlock() - if ok || pendingOk { + r.mtx.Unlock() return ErrAlreadyConnected } + if !pendingOk { + r.pendingConnections[addr] = conn + } + r.nsqdTCPAddrs = append(r.nsqdTCPAddrs, addr) + r.mtx.Unlock() r.log(LogLevelInfo, "(%s) connecting to nsqd", addr) - conn := NewConn(addr, &r.config, &consumerConnDelegate{r}) - conn.SetLogger(r.logger, r.logLvl, - fmt.Sprintf("%3d [%s/%s] (%%s)", r.id, r.topic, r.channel)) - cleanupConnection := func() { r.mtx.Lock() delete(r.pendingConnections, addr) @@ -477,8 +483,6 @@ func (r *Consumer) ConnectToNSQD(addr string) error { conn.Close() } - r.pendingConnections[addr] = true - resp, err := conn.Connect() if err != nil { cleanupConnection() @@ -501,8 +505,8 @@ func (r *Consumer) ConnectToNSQD(addr string) error { conn, r.topic, r.channel, err.Error()) } - delete(r.pendingConnections, addr) r.mtx.Lock() + delete(r.pendingConnections, addr) r.connections[addr] = conn r.mtx.Unlock() @@ -514,6 +518,57 @@ func (r *Consumer) ConnectToNSQD(addr string) error { return nil } +func indexOf(n string, h []string) int { + for i, a := range h { + if n == a { + return i + } + } + return -1 +} + +func (r *Consumer) DisconnectFromNSQD(addr string) error { + r.mtx.Lock() + defer r.mtx.Unlock() + + idx := indexOf(addr, r.nsqdTCPAddrs) + if idx == -1 { + return ErrNotConnected + } + + // slice delete + r.nsqdTCPAddrs = append(r.nsqdTCPAddrs[:idx], r.nsqdTCPAddrs[idx+1:]...) + + pendingConn, pendingOk := r.pendingConnections[addr] + conn, ok := r.connections[addr] + + if ok { + conn.Close() + } else if pendingOk { + pendingConn.Close() + } + + return nil +} + +func (r *Consumer) DisconnectFromNSQLookupd(addr string) error { + r.mtx.Lock() + defer r.mtx.Unlock() + + idx := indexOf(addr, r.lookupdHTTPAddrs) + if idx == -1 { + return ErrNotConnected + } + + if len(r.lookupdHTTPAddrs) == 1 { + return errors.New(fmt.Sprintf("cannot disconnect from only remaining nsqlookupd HTTP address %s", addr)) + } + + r.lookupdHTTPAddrs = append(r.lookupdHTTPAddrs[:idx], r.lookupdHTTPAddrs[idx+1:]...) + + return nil +} + func (r *Consumer) onConnMessage(c *Conn, msg *Message) { atomic.AddInt64(&r.totalRdyCount, -1) atomic.AddUint64(&r.messagesReceived, 1) @@ -664,22 +719,26 @@ func (r *Consumer) onConnClose(c *Conn) { } // we were the last one (and stopping) - if left == 0 && atomic.LoadInt32(&r.stopFlag) == 1 { - r.stopHandlers() + if atomic.LoadInt32(&r.stopFlag) == 1 { + if left == 0 { + r.stopHandlers() + } return } r.mtx.RLock() numLookupd := len(r.lookupdHTTPAddrs) + reconnect := indexOf(c.String(), r.nsqdTCPAddrs) >= 0 r.mtx.RUnlock() - if numLookupd != 0 && atomic.LoadInt32(&r.stopFlag) == 0 { + if numLookupd > 0 { // trigger a poll of the lookupd select { case r.lookupdRecheckChan <- 1: default: } - } else if numLookupd == 0 && atomic.LoadInt32(&r.stopFlag) == 0 { - // there are no lookupd, try to reconnect after a bit + } else if reconnect { + // there are no lookupd and we still have this nsqd TCP address in our list... + // try to reconnect after a bit go func(addr string) { for { r.log(LogLevelInfo, "(%s) re-connecting in 15 seconds...", addr) @@ -687,6 +746,13 @@ func (r *Consumer) onConnClose(c *Conn) { if atomic.LoadInt32(&r.stopFlag) == 1 { break } + r.mtx.RLock() + reconnect := indexOf(addr, r.nsqdTCPAddrs) >= 0 + r.mtx.RUnlock() + if !reconnect { + r.log(LogLevelWarning, "(%s) skipped reconnect after removal...", addr) + return + } err := r.ConnectToNSQD(addr) if err != nil && err != ErrAlreadyConnected { r.log(LogLevelError, "(%s) error connecting to nsqd - %s", addr, err) diff --git a/consumer_test.go b/consumer_test.go index 2e66520c..467206cc 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -158,7 +158,7 @@ func consumerTest(t *testing.T, cb func(c *Config)) { } topicName = topicName + strconv.Itoa(int(time.Now().Unix())) q, _ := NewConsumer(topicName, "ch", config) - q.SetLogger(nullLogger, LogLevelInfo) + // q.SetLogger(nullLogger, LogLevelInfo) h := &MyTestHandler{ t: t, @@ -182,6 +182,21 @@ func consumerTest(t *testing.T, cb func(c *Config)) { t.Fatal("should not be able to connect to the same NSQ twice") } + err = q.DisconnectFromNSQD("1.2.3.4:4150") + if err == nil { + t.Fatal("should not be able to disconnect from an unknown nsqd") + } + + err = q.ConnectToNSQD("1.2.3.4:4150") + if err == nil { + t.Fatal("should not be able to connect to non-existent nsqd") + } + + err = q.DisconnectFromNSQD("1.2.3.4:4150") + if err != nil { + t.Fatal("should be able to disconnect from an nsqd - " + err.Error()) + } + <-q.StopChan if h.messagesReceived != 8 || h.messagesSent != 4 {