diff --git a/src/main/java/redis/clients/jedis/Connection.java b/src/main/java/redis/clients/jedis/Connection.java index 68eaee3b74..2860866c6e 100644 --- a/src/main/java/redis/clients/jedis/Connection.java +++ b/src/main/java/redis/clients/jedis/Connection.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.function.Supplier; import redis.clients.jedis.Protocol.Command; @@ -41,6 +42,8 @@ public class Connection implements Closeable { private boolean broken = false; private boolean strValActive; private String strVal; + protected String server; + protected String version; public Connection() { this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT); @@ -453,12 +456,12 @@ protected void initializeFromClientConfig(final JedisClientConfig config) { final RedisCredentialsProvider redisCredentialsProvider = (RedisCredentialsProvider) credentialsProvider; try { redisCredentialsProvider.prepare(); - helloOrAuth(protocol, redisCredentialsProvider.get()); + helloAndAuth(protocol, redisCredentialsProvider.get()); } finally { redisCredentialsProvider.cleanUp(); } } else { - helloOrAuth(protocol, credentialsProvider != null ? credentialsProvider.get() + helloAndAuth(protocol, credentialsProvider != null ? credentialsProvider.get() : new DefaultRedisCredentials(config.getUser(), config.getPassword())); } @@ -517,50 +520,56 @@ protected void initializeFromClientConfig(final JedisClientConfig config) { } } - private void helloOrAuth(final RedisProtocol protocol, final RedisCredentials credentials) { - - if (credentials == null || credentials.getPassword() == null) { - if (protocol != null) { - sendCommand(Command.HELLO, encode(protocol.version())); - getOne(); + private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials credentials) { + Map helloResult = null; + if (protocol != null && credentials != null && credentials.getUser() != null) { + byte[] rawPass = encodeToBytes(credentials.getPassword()); + try { + helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(), encode(credentials.getUser()), rawPass); + } finally { + Arrays.fill(rawPass, (byte) 0); // clear sensitive data } - return; + } else { + auth(credentials); + helloResult = protocol == null ? null : hello(encode(protocol.version())); + } + if (helloResult != null) { + server = (String) helloResult.get("server"); + version = (String) helloResult.get("version"); } - // Source: https://stackoverflow.com/a/9670279/4021802 - ByteBuffer passBuf = Protocol.CHARSET.encode(CharBuffer.wrap(credentials.getPassword())); - byte[] rawPass = Arrays.copyOfRange(passBuf.array(), passBuf.position(), passBuf.limit()); - Arrays.fill(passBuf.array(), (byte) 0); // clear sensitive data + // clearing 'char[] credentials.getPassword()' should be + // handled in RedisCredentialsProvider.cleanUp() + } + private void auth(RedisCredentials credentials) { + if (credentials == null || credentials.getPassword() == null) { + return; + } + byte[] rawPass = encodeToBytes(credentials.getPassword()); try { - /// actual HELLO or AUTH --> - if (protocol != null) { - if (credentials.getUser() != null) { - sendCommand(Command.HELLO, encode(protocol.version()), - Keyword.AUTH.getRaw(), encode(credentials.getUser()), rawPass); - getOne(); // Map - } else { - sendCommand(Command.AUTH, rawPass); - getStatusCodeReply(); // OK - sendCommand(Command.HELLO, encode(protocol.version())); - getOne(); // Map - } - } else { // protocol == null - if (credentials.getUser() != null) { - sendCommand(Command.AUTH, encode(credentials.getUser()), rawPass); - } else { - sendCommand(Command.AUTH, rawPass); - } - getStatusCodeReply(); // OK + if (credentials.getUser() == null) { + sendCommand(Command.AUTH, rawPass); + } else { + sendCommand(Command.AUTH, encode(credentials.getUser()), rawPass); } - /// <-- actual HELLO or AUTH } finally { - Arrays.fill(rawPass, (byte) 0); // clear sensitive data } + getStatusCodeReply(); + } - // clearing 'char[] credentials.getPassword()' should be - // handled in RedisCredentialsProvider.cleanUp() + protected Map hello(byte[]... args) { + sendCommand(Command.HELLO, args); + return BuilderFactory.ENCODED_OBJECT_MAP.build(getOne()); + } + + protected byte[] encodeToBytes(char[] chars) { + // Source: https://stackoverflow.com/a/9670279/4021802 + ByteBuffer passBuf = Protocol.CHARSET.encode(CharBuffer.wrap(chars)); + byte[] rawPass = Arrays.copyOfRange(passBuf.array(), passBuf.position(), passBuf.limit()); + Arrays.fill(passBuf.array(), (byte) 0); // clear sensitive data + return rawPass; } public String select(final int index) { diff --git a/src/main/java/redis/clients/jedis/csc/AbstractCache.java b/src/main/java/redis/clients/jedis/csc/AbstractCache.java index fc936b5baf..84b4d2ef81 100644 --- a/src/main/java/redis/clients/jedis/csc/AbstractCache.java +++ b/src/main/java/redis/clients/jedis/csc/AbstractCache.java @@ -193,6 +193,10 @@ public CacheStats getAndResetStats() { return result; } + @Override + public boolean compatibilityMode() { + return false; + } // End of Cache interface methods // abstract methods to be implemented by the concrete classes diff --git a/src/main/java/redis/clients/jedis/csc/Cache.java b/src/main/java/redis/clients/jedis/csc/Cache.java index 49413bc0de..0bf4592b59 100644 --- a/src/main/java/redis/clients/jedis/csc/Cache.java +++ b/src/main/java/redis/clients/jedis/csc/Cache.java @@ -105,4 +105,9 @@ public interface Cache { * @return The statistics of the cache */ CacheStats getAndResetStats(); + + /** + * @return The compatibility of cache against different Redis versions + */ + boolean compatibilityMode(); } diff --git a/src/main/java/redis/clients/jedis/csc/CacheConnection.java b/src/main/java/redis/clients/jedis/csc/CacheConnection.java index c19dd319ec..f157d95a94 100644 --- a/src/main/java/redis/clients/jedis/csc/CacheConnection.java +++ b/src/main/java/redis/clients/jedis/csc/CacheConnection.java @@ -16,6 +16,8 @@ public class CacheConnection extends Connection { private final Cache cache; private ReentrantLock lock; + private static final String REDIS = "redis"; + private static final String MIN_REDIS_VERSION = "7.4"; public CacheConnection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig, Cache cache) { super(socketFactory, clientConfig); @@ -23,6 +25,13 @@ public CacheConnection(final JedisSocketFactory socketFactory, JedisClientConfig if (protocol != RedisProtocol.RESP3) { throw new JedisException("Client side caching is only supported with RESP3."); } + if (!cache.compatibilityMode()) { + RedisVersion current = new RedisVersion(version); + RedisVersion required = new RedisVersion(MIN_REDIS_VERSION); + if (!REDIS.equals(server) || current.compareTo(required) < 0) { + throw new JedisException(String.format("Client side caching is only supported with 'Redis %s' or later.", MIN_REDIS_VERSION)); + } + } this.cache = Objects.requireNonNull(cache); initializeClientSideCache(); } diff --git a/src/main/java/redis/clients/jedis/csc/RedisVersion.java b/src/main/java/redis/clients/jedis/csc/RedisVersion.java new file mode 100644 index 0000000000..2daf6393c7 --- /dev/null +++ b/src/main/java/redis/clients/jedis/csc/RedisVersion.java @@ -0,0 +1,41 @@ +package redis.clients.jedis.csc; + +import java.util.Arrays; + +class RedisVersion implements Comparable { + + private String version; + private Integer[] numbers; + + public RedisVersion(String version) { + if (version == null) throw new IllegalArgumentException("Version can not be null"); + this.version = version; + this.numbers = Arrays.stream(version.split("\\.")).map(n -> Integer.parseInt(n)).toArray(Integer[]::new); + } + + @Override + public int compareTo(RedisVersion other) { + int max = Math.max(this.numbers.length, other.numbers.length); + for (int i = 0; i < max; i++) { + int thisNumber = this.numbers.length > i ? this.numbers[i]:0; + int otherNumber = other.numbers.length > i ? other.numbers[i]:0; + if (thisNumber < otherNumber) return -1; + if (thisNumber > otherNumber) return 1; + } + return 0; + } + + @Override + public String toString() { + return this.version; + } + + @Override + public boolean equals(Object that) { + if (this == that) return true; + if (that == null) return false; + if (this.getClass() != that.getClass()) return false; + return this.compareTo((RedisVersion) that) == 0; + } + +} diff --git a/src/test/java/redis/clients/jedis/csc/VersionTest.java b/src/test/java/redis/clients/jedis/csc/VersionTest.java new file mode 100644 index 0000000000..b154e88a05 --- /dev/null +++ b/src/test/java/redis/clients/jedis/csc/VersionTest.java @@ -0,0 +1,30 @@ +package redis.clients.jedis.csc; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +public class VersionTest { + + @Test + public void compareSameVersions() { + RedisVersion a = new RedisVersion("5.2.4"); + RedisVersion b = new RedisVersion("5.2.4"); + assertEquals(a, b); + + RedisVersion c = new RedisVersion("5.2.0.0"); + RedisVersion d = new RedisVersion("5.2"); + assertEquals(a, b); + } + + @Test + public void compareDifferentVersions() { + RedisVersion a = new RedisVersion("5.2.4"); + RedisVersion b = new RedisVersion("5.1.4"); + assertEquals(1, a.compareTo(b)); + + RedisVersion c = new RedisVersion("5.2.4"); + RedisVersion d = new RedisVersion("5.2.5"); + assertEquals(-1, c.compareTo(d)); + } +}