diff --git a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/codec/BooleanCodec.java b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/codec/BooleanCodec.java index c8f641cc..15dbdc19 100644 --- a/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/codec/BooleanCodec.java +++ b/r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/codec/BooleanCodec.java @@ -40,27 +40,33 @@ private BooleanCodec() { @Override public Boolean decode(ByteBuf value, MySqlReadableMetadata metadata, Class target, boolean binary, CodecContext context) { - if (!value.isReadable()) { - return createFromLong(0); - } + MySqlType dataType = metadata.getType(); + + if (dataType == MySqlType.VARCHAR) { + if (!value.isReadable()) { + return createFromLong(0); + } - String s = value.toString(metadata.getCharCollation(context).getCharset()); - - if (s.equalsIgnoreCase("Y") || s.equalsIgnoreCase("yes") || - s.equalsIgnoreCase("T") || s.equalsIgnoreCase("true")) { - return createFromLong(1); - } else if (s.equalsIgnoreCase("N") || s.equalsIgnoreCase("no") || - s.equalsIgnoreCase("F") || s.equalsIgnoreCase("false")) { - return createFromLong(0); - } else if (s.contains("e") || s.contains("E") || s.matches("-?\\d*\\.\\d*")) { - return createFromDouble(Double.parseDouble(s)); - } else if (s.matches("-?\\d+")) { - if (!CodecUtils.isGreaterThanLongMax(s)) { - return createFromLong(CodecUtils.parseLong(value)); + String s = value.toString(metadata.getCharCollation(context).getCharset()); + + if (s.equalsIgnoreCase("Y") || s.equalsIgnoreCase("yes") || + s.equalsIgnoreCase("T") || s.equalsIgnoreCase("true")) { + return createFromLong(1); + } else if (s.equalsIgnoreCase("N") || s.equalsIgnoreCase("no") || + s.equalsIgnoreCase("F") || s.equalsIgnoreCase("false")) { + return createFromLong(0); + } else if (s.contains("e") || s.contains("E") || s.matches("-?\\d*\\.\\d*")) { + return createFromDouble(Double.parseDouble(s)); + } else if (s.matches("-?\\d+")) { + if (!CodecUtils.isGreaterThanLongMax(s)) { + return createFromLong(CodecUtils.parseLong(value)); + } + return createFromBigInteger(new BigInteger(s)); } - return createFromBigInteger(new BigInteger(s)); + throw new IllegalArgumentException("Unable to interpret string: " + s); } - throw new IllegalArgumentException("Unable to interpret string: " + s); + + return binary || dataType == MySqlType.BIT ? value.readBoolean() : value.readByte() != '0'; } @Override @@ -76,7 +82,8 @@ public MySqlParameter encode(Object value, CodecContext context) { @Override public boolean doCanDecode(MySqlReadableMetadata metadata) { MySqlType type = metadata.getType(); - return type == MySqlType.BIT || type == MySqlType.VARCHAR || type.isNumeric(); + return ((type == MySqlType.BIT || type == MySqlType.TINYINT) && + Integer.valueOf(1).equals(metadata.getPrecision())) || type == MySqlType.VARCHAR; } public Boolean createFromLong(long l) { diff --git a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/codec/BooleanCodecTest.java b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/codec/BooleanCodecTest.java index 14146710..3150bb37 100644 --- a/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/codec/BooleanCodecTest.java +++ b/r2dbc-mysql/src/test/java/io/asyncer/r2dbc/mysql/codec/BooleanCodecTest.java @@ -29,6 +29,8 @@ import org.junit.jupiter.api.Test; +import java.nio.ByteBuffer; + /** * Unit tests for {@link BooleanCodec}. */ @@ -67,18 +69,27 @@ public ByteBuf sized(ByteBuf value) { void decodeString() { Codec codec = getCodec(); Charset c = ConnectionContextTest.mock().getClientCollation().getCharset(); + byte[] bOne = new byte[]{(byte)1}; + byte[] bZero = new byte[]{(byte)0}; + ByteBuffer bitValOne = ByteBuffer.wrap(bOne); + ByteBuffer bitValZero = ByteBuffer.wrap(bZero); Decoding d1 = new Decoding(Unpooled.copiedBuffer("true", c), "true", MySqlType.VARCHAR); Decoding d2 = new Decoding(Unpooled.copiedBuffer("false", c), "false", MySqlType.VARCHAR); Decoding d3 = new Decoding(Unpooled.copiedBuffer("1", c), "1", MySqlType.VARCHAR); Decoding d4 = new Decoding(Unpooled.copiedBuffer("0", c), "0", MySqlType.VARCHAR); Decoding d5 = new Decoding(Unpooled.copiedBuffer("Y", c), "Y", MySqlType.VARCHAR); Decoding d6 = new Decoding(Unpooled.copiedBuffer("no", c), "no", MySqlType.VARCHAR); - Decoding d7 = new Decoding(Unpooled.copyDouble(26.57), 26.57, MySqlType.DOUBLE); - Decoding d8 = new Decoding(Unpooled.copyLong(-57), -57, MySqlType.TINYINT); - Decoding d9 = new Decoding(Unpooled.copyLong(100000), 100000, MySqlType.BIGINT); + Decoding d7 = new Decoding(Unpooled.copiedBuffer("26.57", c), "26.57", MySqlType.VARCHAR); + Decoding d8 = new Decoding(Unpooled.copiedBuffer("-57", c), "=57", MySqlType.VARCHAR); + Decoding d9 = new Decoding(Unpooled.copiedBuffer("100000", c), "100000", MySqlType.VARCHAR); Decoding d10 = new Decoding(Unpooled.copiedBuffer("-12345678901234567890", c), "-12345678901234567890", MySqlType.VARCHAR); Decoding d11 = new Decoding(Unpooled.copiedBuffer("Banana", c), "Banana", MySqlType.VARCHAR); + Decoding d12 = new Decoding(Unpooled.copiedBuffer(bitValOne), bitValOne, MySqlType.BIT); + Decoding d13 = new Decoding(Unpooled.copiedBuffer(bitValZero), bitValZero, MySqlType.BIT); + Decoding d14 = new Decoding(Unpooled.copyDouble(26.57d), 26.57d, MySqlType.DOUBLE); + Decoding d15 = new Decoding(Unpooled.copiedBuffer(bOne), bOne, MySqlType.TINYINT); + Decoding d16 = new Decoding(Unpooled.copiedBuffer(bZero), bZero, MySqlType.TINYINT); assertThat(codec.decode(d1.content(), d1.metadata(), Boolean.class, false, ConnectionContextTest.mock())) .as("Decode failed, %s", d1) @@ -122,5 +133,25 @@ void decodeString() { assertThatThrownBy(() -> {codec.decode(d11.content(), d11.metadata(), Boolean.class, false, ConnectionContextTest.mock());}) .isInstanceOf(IllegalArgumentException.class); + + assertThat(codec.decode(d12.content(), d12.metadata(), Boolean.class, false, ConnectionContextTest.mock())) + .as("Decode failed, %s", d12) + .isEqualTo(true); + + assertThat(codec.decode(d13.content(), d13.metadata(), Boolean.class, false, ConnectionContextTest.mock())) + .as("Decode failed, %s", d13) + .isEqualTo(false); + + assertThat(codec.decode(d14.content(), d14.metadata(), Boolean.class, false, ConnectionContextTest.mock())) + .as("Decode failed, %s", d14) + .isEqualTo(true); + + assertThat(codec.decode(d15.content(), d15.metadata(), Boolean.class, true, ConnectionContextTest.mock())) + .as("Decode failed, %s", d15) + .isEqualTo(true); + + assertThat(codec.decode(d16.content(), d16.metadata(), Boolean.class, true, ConnectionContextTest.mock())) + .as("Decode failed, %s", d14) + .isEqualTo(false); } }