From ba6c2b4739ccdb0a9d8df0959a2c4b550debe6c8 Mon Sep 17 00:00:00 2001
From: Alp Aker <alp.aker@buzzfeed.com>
Date: Tue, 9 May 2017 12:18:27 -0400
Subject: [PATCH] reader:  Don't decrement total_rdy on message receipt. 
 Adjust RDY redistribution logic accordingly.

This brings reader behavior into agreement with nsqd behavior (compare nsqio/nsq#404)
and removes an opportunity for max_in_flight violations (#177).
---
 nsq/async.py          |  1 -
 nsq/reader.py         | 56 +++++++++++++++++++++----------------------
 tests/test_backoff.py |  1 +
 tests/test_reader.py  |  2 +-
 4 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/nsq/async.py b/nsq/async.py
index ab88c51..d44500d 100644
--- a/nsq/async.py
+++ b/nsq/async.py
@@ -486,7 +486,6 @@ def _on_data(self, data, **kwargs):
         frame, data = protocol.unpack_response(data)
         if frame == protocol.FRAME_TYPE_MESSAGE:
             self.last_msg_timestamp = time.time()
-            self.rdy = max(self.rdy - 1, 0)
             self.in_flight += 1
 
             message = protocol.decode_message(data)
diff --git a/nsq/reader.py b/nsq/reader.py
index 6e11073..69f5538 100644
--- a/nsq/reader.py
+++ b/nsq/reader.py
@@ -317,21 +317,7 @@ def _on_message(self, conn, message, **kwargs):
             logger.exception('[%s:%s] failed to handle_message() %r', conn.id, self.name, message)
 
     def _handle_message(self, conn, message):
-        self.total_rdy = max(self.total_rdy - 1, 0)
-
-        rdy_conn = conn
-        if len(self.conns) > self.max_in_flight and time.time() - self.random_rdy_ts > 30:
-            # if all connections aren't getting RDY
-            # occsionally randomize which connection gets RDY
-            self.random_rdy_ts = time.time()
-            conns_with_no_rdy = [c for c in itervalues(self.conns) if not c.rdy]
-            if conns_with_no_rdy:
-                rdy_conn = random.choice(conns_with_no_rdy)
-                if rdy_conn is not conn:
-                    logger.info('[%s:%s] redistributing RDY to %s',
-                                conn.id, self.name, rdy_conn.id)
-
-        self._maybe_update_rdy(rdy_conn)
+        self._maybe_update_rdy(conn)
 
         success = False
         try:
@@ -358,7 +344,9 @@ def _maybe_update_rdy(self, conn):
         if self.backoff_timer.get_interval() or self.max_in_flight == 0:
             return
 
-        if conn.rdy <= 1 or conn.rdy < int(conn.last_rdy * 0.25):
+        # On a new connection or in backoff we start with a tentative RDY count
+        # of 1.  After successfully receiving a first message we go to full throttle.
+        if conn.rdy == 1:
             self._send_rdy(conn, self._connection_max_in_flight())
 
     def _finish_backoff_block(self):
@@ -452,15 +440,10 @@ def _send_rdy(self, conn, value):
         if value > conn.max_rdy_count:
             value = conn.max_rdy_count
 
-        if (self.total_rdy + value) > self.max_in_flight:
-            if not conn.rdy:
-                # if we're going from RDY 0 to non-0 and we couldn't because
-                # of the configured max in flight, try again
-                rdy_retry_callback = functools.partial(self._rdy_retry, conn, value)
-                conn.rdy_timeout = self.io_loop.add_timeout(time.time() + 5, rdy_retry_callback)
+        new_rdy = max(self.total_rdy - conn.rdy + value, 0)
+        if new_rdy > self.max_in_flight:
             return
 
-        new_rdy = max(self.total_rdy - conn.rdy + value, 0)
         if conn.send_rdy(value):
             self.total_rdy = new_rdy
 
@@ -665,10 +648,27 @@ def _redistribute_rdy_state(self):
                     logger.info('[%s:%s] idle connection, giving up RDY count', conn.id, self.name)
                     self._send_rdy(conn, 0)
 
+            conns = self.conns.values()
+
+            in_flight_or_rdy = len([c for c in conns if c.in_flight or c.rdy])
             if backoff_interval:
-                max_in_flight = 1 - self.total_rdy
+                available_rdy = max(0, 1 - in_flight_or_rdy)
             else:
-                max_in_flight = self.max_in_flight - self.total_rdy
+                available_rdy = self.max_in_flight - in_flight_or_rdy
+
+            # if moving any connections from RDY 0 to non-0 would violate in-flight constraints,
+            # set RDY 0 on some connection with msgs in flight so that a later redistribution
+            # round can proceed and we don't stay pinned to the same connections.
+            #
+            # if nothing's in flight, then we have connections with RDY 1 that are still
+            # waiting to hit the idle timeout, in which case it's ok to do nothing.
+            if not available_rdy:
+                try:
+                    c = random.choice([c for c in conns if c.in_flight])
+                    logger.info('[%s:%s] too many msgs in flight, giving up RDY count', c.id, self.name)
+                    self._send_rdy(c, 0)
+                except IndexError:
+                    pass
 
             # randomly walk the list of possible connections and send RDY 1 (up to our
             # calculated "max_in_flight").  We only need to send RDY 1 because in both
@@ -677,9 +677,9 @@ def _redistribute_rdy_state(self):
             # We also don't attempt to avoid the connections who previously might have had RDY 1
             # because it would be overly complicated and not actually worth it (ie. given enough
             # redistribution rounds it doesn't matter).
-            possible_conns = list(self.conns.values())
-            while possible_conns and max_in_flight:
-                max_in_flight -= 1
+            possible_conns = [c for c in conns if not (c.in_flight or c.rdy)]
+            while possible_conns and available_rdy:
+                available_rdy -= 1
                 conn = possible_conns.pop(random.randrange(len(possible_conns)))
                 logger.info('[%s:%s] redistributing RDY', conn.id, self.name)
                 self._send_rdy(conn, 1)
diff --git a/tests/test_backoff.py b/tests/test_backoff.py
index 3db5f08..b216da7 100644
--- a/tests/test_backoff.py
+++ b/tests/test_backoff.py
@@ -52,6 +52,7 @@ def _get_conn(reader):
 
 def _send_message(conn):
     msg = _get_message(conn)
+    conn.in_flight += 1
     conn.trigger(event.MESSAGE, conn=conn, message=msg)
     return msg
 
diff --git a/tests/test_reader.py b/tests/test_reader.py
index 7d7f09e..e6e46aa 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -160,7 +160,7 @@ def test_conn_messages(self):
 
         def _on_message(*args, **kwargs):
             self.msg_count += 1
-            if c.rdy == 0:
+            if c.in_flight == 5:
                 self.stop()
 
         def _on_ready(*args, **kwargs):