Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: disable checking for uint_8 and uint_16 if complex type readers are enabled #1376

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,11 @@ object CometSparkSessionExtensions extends Logging {
org.apache.spark.SPARK_VERSION >= "4.0"
}

def isComplexTypeReaderEnabled(conf: SQLConf): Boolean = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find the naming confusing here. This method determines if we are using native_datafusion or native_iceberg_compat (which both use DataFusion's ParquetExec). This is no logic related to complex types.

Complex type support was a big motivation for adding these new scans, but it doesn't seem to make sense to refer to complex types in the changes in this PR.

This is just a nit, and we can rename the methods in a future PR.

CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) == CometConf.SCAN_NATIVE_ICEBERG_COMPAT ||
CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) == CometConf.SCAN_NATIVE_DATAFUSION
}

/** Calculates required memory overhead in MB per executor process for Comet. */
def getCometMemoryOverheadInMiB(sparkConf: SparkConf): Long = {
// `spark.executor.memory` default value is 1g
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ trait DataTypeSupport {
true
case t: DataType if t.typeName == "timestamp_ntz" =>
true
true
case _ => false
}

Expand Down
21 changes: 21 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("uint data type support") {
Seq(true, false).foreach { dictionaryEnabled =>
Seq(Byte.MaxValue, Short.MaxValue).foreach { valueRanges =>
{
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "testuint.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, valueRanges + 1)
withParquetTable(path.toString, "tbl") {
if (CometSparkSessionExtensions.isComplexTypeReaderEnabled(conf)) {
checkSparkAnswer("select _9, _10 FROM tbl order by _11")
Copy link
Member

@andygrove andygrove Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we already have logic to fall back to Spark when the complex type reader is enabled and when the query references uint Parquet fields?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we don't for two reasons. Firstly, in the plan we get the schema as understood by Spark so all the signed int_8 and int_16 values are indistinguishable from the unsigned ones. As a result we fall back to Spark for both signed and unsigned integers. Secondly, too many unit tests fail because we check that the plan contains a comet operator and would need to be modified.
I'm open to putting it back though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a result we fall back to Spark for both signed and unsigned integers.

Just 8 and 16 bit, or all integers? I'm fine with falling back for 8 and 16 bit for now, although it would be nice to have a config to override this (with the understanding that behavior is incorrect for unsigned integers).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just 8 and 16 bit.
I started with the fallback to spark and a compat override. The reason I reverted it is that I couldn't see a way to get to compatibility with spark even after/if apache/arrow-rs#7040 is addressed.
Let me do as you suggest. Marking this as draft in the meantime.

} else {
checkSparkAnswerAndOperator("select _9, _10 FROM tbl order by _11")
}
}
}
}
}
}
}

test("null literals") {
val batchSize = 1000
Seq(true, false).foreach { dictionaryEnabled =>
Expand All @@ -142,6 +162,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
checkSparkAnswerAndOperator(sqlString)
}
}

}
}

Expand Down
169 changes: 117 additions & 52 deletions spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,65 +429,130 @@ abstract class CometTestBase
makeParquetFileAllTypes(path, dictionaryEnabled, 0, n)
}

def makeParquetFileAllTypes(
path: Path,
dictionaryEnabled: Boolean,
begin: Int,
end: Int,
pageSize: Int = 128,
randomSize: Int = 0): Unit = {
val schemaStr =
def getAllTypesParquetSchema: String = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are renaming this method, I wonder if we should remove the AllTypes part since it does not generate all types. Perhaps getPrimitiveTypesParquetSchema?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (CometSparkSessionExtensions.isComplexTypeReaderEnabled(conf)) {
// Comet complex type reader has different behavior for uint_8, uint_16 types.
// The issue stems from undefined behavior in the parquet spec and is tracked
// here: https://github.com/apache/parquet-java/issues/3142
// here: https://github.com/apache/arrow-rs/issues/7040
// and here: https://github.com/apache/datafusion-comet/issues/1348
if (isSpark34Plus) {
"""
|message root {
| optional boolean _1;
| optional int32 _2(INT_8);
| optional int32 _3(INT_16);
| optional int32 _4;
| optional int64 _5;
| optional float _6;
| optional double _7;
| optional binary _8(UTF8);
| optional int32 _9(UINT_8);
| optional int32 _10(UINT_16);
| optional int32 _11(UINT_32);
| optional int64 _12(UINT_64);
| optional binary _13(ENUM);
| optional FIXED_LEN_BYTE_ARRAY(3) _14;
| optional int32 _15(DECIMAL(5, 2));
| optional int64 _16(DECIMAL(18, 10));
| optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
| optional INT64 _18(TIMESTAMP(MILLIS,true));
| optional INT64 _19(TIMESTAMP(MICROS,true));
| optional INT32 _20(DATE);
|}
|message root {
| optional boolean _1;
| optional int32 _2(INT_8);
| optional int32 _3(INT_16);
| optional int32 _4;
| optional int64 _5;
| optional float _6;
| optional double _7;
| optional binary _8(UTF8);
| optional int32 _9(UINT_32);
| optional int32 _10(UINT_32);
| optional int32 _11(UINT_32);
| optional int64 _12(UINT_64);
| optional binary _13(ENUM);
| optional FIXED_LEN_BYTE_ARRAY(3) _14;
| optional int32 _15(DECIMAL(5, 2));
| optional int64 _16(DECIMAL(18, 10));
| optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
| optional INT64 _18(TIMESTAMP(MILLIS,true));
| optional INT64 _19(TIMESTAMP(MICROS,true));
| optional INT32 _20(DATE);
|}
""".stripMargin
} else {
"""
|message root {
| optional boolean _1;
| optional int32 _2(INT_8);
| optional int32 _3(INT_16);
| optional int32 _4;
| optional int64 _5;
| optional float _6;
| optional double _7;
| optional binary _8(UTF8);
| optional int32 _9(UINT_8);
| optional int32 _10(UINT_16);
| optional int32 _11(UINT_32);
| optional int64 _12(UINT_64);
| optional binary _13(ENUM);
| optional binary _14(UTF8);
| optional int32 _15(DECIMAL(5, 2));
| optional int64 _16(DECIMAL(18, 10));
| optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
| optional INT64 _18(TIMESTAMP(MILLIS,true));
| optional INT64 _19(TIMESTAMP(MICROS,true));
| optional INT32 _20(DATE);
|}
|message root {
| optional boolean _1;
| optional int32 _2(INT_8);
| optional int32 _3(INT_16);
| optional int32 _4;
| optional int64 _5;
| optional float _6;
| optional double _7;
| optional binary _8(UTF8);
| optional int32 _9(UINT_32);
| optional int32 _10(UINT_32);
| optional int32 _11(UINT_32);
| optional int64 _12(UINT_64);
| optional binary _13(ENUM);
| optional binary _14(UTF8);
| optional int32 _15(DECIMAL(5, 2));
| optional int64 _16(DECIMAL(18, 10));
| optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
| optional INT64 _18(TIMESTAMP(MILLIS,true));
| optional INT64 _19(TIMESTAMP(MICROS,true));
| optional INT32 _20(DATE);
|}
""".stripMargin
}
} else {

if (isSpark34Plus) {
"""
|message root {
| optional boolean _1;
| optional int32 _2(INT_8);
| optional int32 _3(INT_16);
| optional int32 _4;
| optional int64 _5;
| optional float _6;
| optional double _7;
| optional binary _8(UTF8);
| optional int32 _9(UINT_8);
| optional int32 _10(UINT_16);
| optional int32 _11(UINT_32);
| optional int64 _12(UINT_64);
| optional binary _13(ENUM);
| optional FIXED_LEN_BYTE_ARRAY(3) _14;
| optional int32 _15(DECIMAL(5, 2));
| optional int64 _16(DECIMAL(18, 10));
| optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
| optional INT64 _18(TIMESTAMP(MILLIS,true));
| optional INT64 _19(TIMESTAMP(MICROS,true));
| optional INT32 _20(DATE);
|}
""".stripMargin
} else {
"""
|message root {
| optional boolean _1;
| optional int32 _2(INT_8);
| optional int32 _3(INT_16);
| optional int32 _4;
| optional int64 _5;
| optional float _6;
| optional double _7;
| optional binary _8(UTF8);
| optional int32 _9(UINT_8);
| optional int32 _10(UINT_16);
| optional int32 _11(UINT_32);
| optional int64 _12(UINT_64);
| optional binary _13(ENUM);
| optional binary _14(UTF8);
| optional int32 _15(DECIMAL(5, 2));
| optional int64 _16(DECIMAL(18, 10));
| optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
| optional INT64 _18(TIMESTAMP(MILLIS,true));
| optional INT64 _19(TIMESTAMP(MICROS,true));
| optional INT32 _20(DATE);
|}
""".stripMargin
}
}
}

def makeParquetFileAllTypes(
path: Path,
dictionaryEnabled: Boolean,
begin: Int,
end: Int,
pageSize: Int = 128,
randomSize: Int = 0): Unit = {
// alwaysIncludeUnsignedIntTypes means we include unsignedIntTypes in the test even if the
// reader does not support them
val schemaStr = getAllTypesParquetSchema

val schema = MessageTypeParser.parseMessageType(schemaStr)
val writer = createParquetWriter(
Expand Down
Loading