diff --git a/src/main/java/redis/clients/jedis/JedisPubSubBase.java b/src/main/java/redis/clients/jedis/JedisPubSubBase.java index 7092680e33..552310e4de 100644 --- a/src/main/java/redis/clients/jedis/JedisPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisPubSubBase.java @@ -172,7 +172,7 @@ private void process() { } else { throw new JedisException("Unknown message type: " + reply); } - } while (isSubscribed()); + } while (!Thread.currentThread().isInterrupted() && isSubscribed()); // /* Invalidate instance since this thread is no longer listening */ // this.client = null; diff --git a/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java b/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java index f0a251f61f..2b2ce944fe 100644 --- a/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java +++ b/src/main/java/redis/clients/jedis/JedisShardedPubSubBase.java @@ -99,7 +99,7 @@ private void process() { } else { throw new JedisException("Unknown message type: " + reply); } - } while (isSubscribed()); + } while (!Thread.currentThread().isInterrupted() && isSubscribed()); // /* Invalidate instance since this thread is no longer listening */ // this.client = null; diff --git a/src/test/java/redis/clients/jedis/JedisPubSubBaseTest.java b/src/test/java/redis/clients/jedis/JedisPubSubBaseTest.java new file mode 100644 index 0000000000..a7910bd6a3 --- /dev/null +++ b/src/test/java/redis/clients/jedis/JedisPubSubBaseTest.java @@ -0,0 +1,59 @@ +package redis.clients.jedis; + +import junit.framework.TestCase; +import redis.clients.jedis.util.SafeEncoder; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static redis.clients.jedis.Protocol.ResponseKeyword.MESSAGE; +import static redis.clients.jedis.Protocol.ResponseKeyword.SUBSCRIBE; + +public class JedisPubSubBaseTest extends TestCase { + + public void testProceed_givenThreadInterrupt_exitLoop() throws InterruptedException { + // setup + final JedisPubSubBase pubSub = new JedisPubSubBase() { + + @Override + public void onMessage(String channel, String message) { + fail("this should not happen when thread is interrupted"); + } + + @Override + protected String encode(byte[] raw) { + return SafeEncoder.encode(raw); + } + }; + + final Connection mockConnection = mock(Connection.class); + final List mockSubscribe = Arrays.asList( + SUBSCRIBE.getRaw(), "channel".getBytes(), 1L + ); + final List mockResponse = Arrays.asList( + MESSAGE.getRaw(), "channel".getBytes(), "message".getBytes() + ); + + when(mockConnection.getUnflushedObject()). + + thenReturn(mockSubscribe, mockResponse); + + + final CountDownLatch countDownLatch = new CountDownLatch(1); + // action + final Thread thread = new Thread(() -> { + Thread.currentThread().interrupt(); + pubSub.proceed(mockConnection, "channel"); + + countDownLatch.countDown(); + }); + thread.start(); + + assertTrue(countDownLatch.await(10, TimeUnit.MILLISECONDS)); + + } +} diff --git a/src/test/java/redis/clients/jedis/JedisShardedPubSubBaseTest.java b/src/test/java/redis/clients/jedis/JedisShardedPubSubBaseTest.java new file mode 100644 index 0000000000..fb1ecdd87a --- /dev/null +++ b/src/test/java/redis/clients/jedis/JedisShardedPubSubBaseTest.java @@ -0,0 +1,56 @@ +package redis.clients.jedis; + +import junit.framework.TestCase; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static redis.clients.jedis.Protocol.ResponseKeyword.SMESSAGE; +import static redis.clients.jedis.Protocol.ResponseKeyword.SSUBSCRIBE; + +public class JedisShardedPubSubBaseTest extends TestCase { + + public void testProceed_givenThreadInterrupt_exitLoop() throws InterruptedException { + // setup + final JedisShardedPubSubBase pubSub = new JedisShardedPubSubBase() { + + @Override + public void onSMessage(String channel, String message) { + fail("this should not happen when thread is interrupted"); + } + + @Override + protected String encode(byte[] raw) { + return new String(raw); + } + + }; + + final Connection mockConnection = mock(Connection.class); + final List mockSubscribe = Arrays.asList( + SSUBSCRIBE.getRaw(), "channel".getBytes(), 1L + ); + final List mockResponse = Arrays.asList( + SMESSAGE.getRaw(), "channel".getBytes(), "message".getBytes() + ); + when(mockConnection.getUnflushedObject()).thenReturn(mockSubscribe, mockResponse); + + + final CountDownLatch countDownLatch = new CountDownLatch(1); + // action + final Thread thread = new Thread(() -> { + Thread.currentThread().interrupt(); + pubSub.proceed(mockConnection, "channel"); + + countDownLatch.countDown(); + }); + thread.start(); + + assertTrue(countDownLatch.await(10, TimeUnit.MILLISECONDS)); + + } +} \ No newline at end of file