diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/constraints/Invariants.scala b/spark/src/main/scala/org/apache/spark/sql/delta/constraints/Invariants.scala
index 5cd5b1e07ed..fa65fead5eb 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/constraints/Invariants.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/constraints/Invariants.scala
@@ -16,13 +16,15 @@
 
 package org.apache.spark.sql.delta.constraints
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.delta.DeltaErrors
-import org.apache.spark.sql.delta.schema.SchemaUtils
 import org.apache.spark.sql.delta.util.JsonUtils
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+
 
 /**
  * List of invariants that can be defined on a Delta table that will allow us to perform
@@ -71,13 +73,52 @@
 
   /** Extract invariants from the given schema */
   def getFromSchema(schema: StructType, spark: SparkSession): Seq[Constraint] = {
-    val columns = SchemaUtils.filterRecursively(schema, checkComplexTypes = false) { field =>
-      !field.nullable || field.metadata.contains(INVARIANTS_FIELD)
+    /**
+     * Find the fields containing constraints, as well as its nearest nullable ancestor
+     * @return (parent path, the nearest null ancestor idx, field)
+     */
+    def recursiveVisitSchema(
+        columnPath: Seq[String],
+        dataType: DataType,
+        nullableAncestorIdxs: mutable.Buffer[Int]): Seq[(Seq[String], Int, StructField)] = {
+      dataType match {
+        case st: StructType =>
+          st.fields.toList.flatMap { field =>
+            val includeLevel = if (field.metadata.contains(INVARIANTS_FIELD) || !field.nullable) {
+              Seq((
+                columnPath,
+                if (nullableAncestorIdxs.isEmpty) -1 else nullableAncestorIdxs.last,
+                field
+              ))
+            } else {
+              Nil
+            }
+            if (field.nullable) {
+              nullableAncestorIdxs.append(columnPath.size)
+            }
+            val childResults = recursiveVisitSchema(
+              columnPath :+ field.name, field.dataType, nullableAncestorIdxs)
+            if (field.nullable) {
+              nullableAncestorIdxs.trimEnd(1)
+            }
+            includeLevel ++ childResults
+          }
+        case _ => Nil
+      }
     }
-    columns.map {
-      case (parents, field) if !field.nullable =>
-        Constraints.NotNull(parents :+ field.name)
-      case (parents, field) =>
+
+    recursiveVisitSchema(Nil, schema, new mutable.ArrayBuffer[Int]()).map {
+      case (parents, nullableAncestor, field) if !field.nullable =>
+        val fieldPath: Seq[String] = parents :+ field.name
+        if (nullableAncestor != -1) {
+          Constraints.Check("",
+            ArbitraryExpression(spark,
+              s"${parents.take(nullableAncestor + 1).mkString(".")} is null " +
+                s"or ${fieldPath.mkString(".")} is not null").expression)
+        } else {
+          Constraints.NotNull(fieldPath)
+        }
+      case (parents, _, field) =>
         val rule = field.metadata.getString(INVARIANTS_FIELD)
         val invariant = Option(JsonUtils.mapper.readValue[PersistedRule](rule).unwrap) match {
           case Some(PersistedExpression(exprString)) =>
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaColumnRenameSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaColumnRenameSuite.scala
index d7735414f51..30d49739c85 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaColumnRenameSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaColumnRenameSuite.scala
@@ -282,7 +282,7 @@ class DeltaColumnRenameSuite extends QueryTest
         "values ('str3', struct('str1.3', -1), map('k3', 'v3'), array(3, 33))")
     }
 
-    assertException("NOT NULL constraint violated for column: b.c1") {
+    assertException("CHECK constraint ((b IS NULL) OR (b.c1 IS NOT NULL)) violated") {
       spark.sql("insert into t1 " +
         "values ('str3', struct(null, 3), map('k3', 'v3'), array(3, 33))")
     }
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaDDLSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaDDLSuite.scala
index a4fd227dabf..42840f8ba44 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaDDLSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaDDLSuite.scala
@@ -572,7 +572,15 @@ abstract class DeltaDDLTestBase extends QueryTest with DeltaSQLTestUtils {
     if (e == null) {
      fail("Didn't receive a InvariantViolationException.")
    }
-    assert(e.getMessage.contains("NOT NULL constraint violated for column"))
+    val idPattern = "[A-Za-z_][A-Za-z0-9_]*"
+    val idsPattern = s"$idPattern(\\.$idPattern)*"
+    val checkPattern =
+      (s"CHECK constraint \\(\\(${idsPattern} IS NULL\\) OR \\(${idsPattern} IS NOT NULL\\)\\)" +
+        " violated by row with values").r
+    assert(
+      e.getMessage.contains("NOT NULL constraint violated for column") ||
+      checkPattern.findFirstIn(e.getMessage).nonEmpty
+    )
   }
 
   test("ALTER TABLE RENAME TO") {
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/schema/InvariantEnforcementSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/schema/InvariantEnforcementSuite.scala
index 7fe7e1288f7..820126412bd 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/schema/InvariantEnforcementSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/schema/InvariantEnforcementSuite.scala
@@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.{CheckConstraintsTableFeature, DeltaLog, Delta
 import org.apache.spark.sql.delta.actions.{Metadata, TableFeatureProtocolUtils}
 import org.apache.spark.sql.delta.catalog.DeltaTableV2
 import org.apache.spark.sql.delta.constraints.{Constraint, Constraints, Invariants}
-import org.apache.spark.sql.delta.constraints.Constraints.NotNull
-import org.apache.spark.sql.delta.constraints.Invariants.PersistedExpression
+import org.apache.spark.sql.delta.constraints.Constraints.{Check, NotNull}
+import org.apache.spark.sql.delta.constraints.Invariants.{ArbitraryExpression, PersistedExpression}
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
 import org.apache.spark.sql.delta.test.DeltaSQLTestUtils
@@ -192,17 +192,15 @@ class InvariantEnforcementSuite extends QueryTest
       .add("key", StringType, nullable = false)
       .add("value", IntegerType))
     testBatchWriteRejection(
-      NotNull(Seq("key")),
+      Check("", ArbitraryExpression(spark, "top is null or top.key is not null").expression),
       schema,
       spark.createDataFrame(Seq(Row(Row("a", 1)), Row(Row(null, 2))).asJava, schema.asNullable),
       "top.key"
     )
-    testBatchWriteRejection(
-      NotNull(Seq("key")),
-      schema,
-      spark.createDataFrame(Seq(Row(Row("a", 1)), Row(null)).asJava, schema.asNullable),
-      "top.key"
-    )
+    tableWithSchema(schema) { path =>
+      spark.createDataFrame(Seq(Row(Row("a", 1)), Row(null)).asJava, schema.asNullable)
+        .write.mode("append").format("delta").save(path)
+    }
   }
 
   testQuietly("reject non-nullable array column") {