[Delta] When checking the schema for writing, Delta enforces NOT NULL on a nested field only when its parent is not null #4121

Open · wants to merge 1 commit into base: master
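Summary of the change: previously, `Invariants.getFromSchema` emitted a `NotNull` constraint for every non-nullable field, including fields nested inside nullable structs, so a write in which the entire parent struct was null was rejected even though the schema allows the parent itself to be null. With this change, a non-nullable nested field whose nearest ancestor is nullable is instead enforced through a CHECK constraint of the form `parent IS NULL OR parent.child IS NOT NULL`. A minimal sketch of the schema shape in question (field names mirror the `b`/`b.c1` columns referenced in the test changes below):

```scala
import org.apache.spark.sql.types._

// `b` is a nullable struct with a non-nullable child `c1`.
val schema = new StructType()
  .add("a", StringType, nullable = false)   // top level: still NotNull(Seq("a"))
  .add("b",
    new StructType().add("c1", StringType, nullable = false),
    nullable = true)                        // nested: CHECK (b IS NULL OR b.c1 IS NOT NULL)
```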
@@ -16,13 +16,15 @@
 package org.apache.spark.sql.delta.constraints
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.delta.DeltaErrors
 import org.apache.spark.sql.delta.schema.SchemaUtils
 import org.apache.spark.sql.delta.util.JsonUtils
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
 
 /**
  * List of invariants that can be defined on a Delta table that will allow us to perform
@@ -71,13 +73,52 @@ object Invariants {

   /** Extract invariants from the given schema */
   def getFromSchema(schema: StructType, spark: SparkSession): Seq[Constraint] = {
-    val columns = SchemaUtils.filterRecursively(schema, checkComplexTypes = false) { field =>
-      !field.nullable || field.metadata.contains(INVARIANTS_FIELD)
+    /**
+     * Find the fields containing constraints, each paired with its nearest nullable ancestor.
+     * @return (parent path, index of the nearest nullable ancestor in the path, field)
+     */
+    def recursiveVisitSchema(
+        columnPath: Seq[String],
+        dataType: DataType,
+        nullableAncestorIdxs: mutable.Buffer[Int]): Seq[(Seq[String], Int, StructField)] = {
+      dataType match {
+        case st: StructType =>
+          st.fields.toList.flatMap { field =>
+            val includeLevel = if (field.metadata.contains(INVARIANTS_FIELD) || !field.nullable) {
+              Seq((
+                columnPath,
+                if (nullableAncestorIdxs.isEmpty) -1 else nullableAncestorIdxs.last,
+                field
+              ))
+            } else {
+              Nil
+            }
+            if (field.nullable) {
+              nullableAncestorIdxs.append(columnPath.size)
+            }
+            val childResults = recursiveVisitSchema(
+              columnPath :+ field.name, field.dataType, nullableAncestorIdxs)
+            if (field.nullable) {
+              nullableAncestorIdxs.trimEnd(1)
+            }
+            includeLevel ++ childResults
+          }
+        case _ => Nil
+      }
     }
-    columns.map {
-      case (parents, field) if !field.nullable =>
-        Constraints.NotNull(parents :+ field.name)
-      case (parents, field) =>
+
+    recursiveVisitSchema(Nil, schema, new mutable.ArrayBuffer[Int]()).map {
+      case (parents, nullableAncestor, field) if !field.nullable =>
+        val fieldPath: Seq[String] = parents :+ field.name
+        if (nullableAncestor != -1) {
+          Constraints.Check("",
+            ArbitraryExpression(spark,
+              s"${parents.take(nullableAncestor + 1).mkString(".")} is null " +
+                s"or ${fieldPath.mkString(".")} is not null").expression)
+        } else {
+          Constraints.NotNull(fieldPath)
+        }
+      case (parents, _, field) =>
         val rule = field.metadata.getString(INVARIANTS_FIELD)
         val invariant = Option(JsonUtils.mapper.readValue[PersistedRule](rule).unwrap) match {
           case Some(PersistedExpression(exprString)) =>
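To make the traversal concrete: `recursiveVisitSchema` pushes the path index of each nullable field before descending into it and pops it afterwards, so every emitted tuple carries the index of its nearest nullable ancestor at the moment the field is visited. A hedged REPL-style check against the two-level schema sketched above (assumes a build with this patch on the classpath and an active `spark` session):

```scala
import org.apache.spark.sql.delta.constraints.Invariants

// Expected tuples from the traversal, and the constraints they map to:
//   (Seq(),    -1, a)  -> Constraints.NotNull(Seq("a"))
//   (Seq("b"),  0, c1) -> Constraints.Check("", "b is null or b.c1 is not null")
Invariants.getFromSchema(schema, spark).foreach(println)
```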
@@ -282,7 +282,7 @@ class DeltaColumnRenameSuite extends QueryTest
         "values ('str3', struct('str1.3', -1), map('k3', 'v3'), array(3, 33))")
     }
 
-    assertException("NOT NULL constraint violated for column: b.c1") {
+    assertException("CHECK constraint ((b IS NULL) OR (b.c1 IS NOT NULL)) violated") {
       spark.sql("insert into t1 " +
         "values ('str3', struct(null, 3), map('k3', 'v3'), array(3, 33))")
     }
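The expected failure message in this suite changes accordingly: with `b` nullable, a null `b.c1` inside a present `b` now trips the generated CHECK constraint rather than a plain NOT NULL. A hedged sketch of the two interesting cases (`t1`'s full DDL lives earlier in this suite; the second insert is hypothetical, not part of the test):

```scala
// Fails: b is present but b.c1 is null, violating
// CHECK ((b IS NULL) OR (b.c1 IS NOT NULL)).
spark.sql("insert into t1 values ('str3', struct(null, 3), map('k3', 'v3'), array(3, 33))")

// Passes under the new semantics: b itself is null, so the CHECK is satisfied.
spark.sql("insert into t1 values ('str3', null, map('k3', 'v3'), array(3, 33))")
```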
@@ -572,7 +572,15 @@ abstract class DeltaDDLTestBase extends QueryTest with DeltaSQLTestUtils {
       if (e == null) {
         fail("Didn't receive a InvariantViolationException.")
       }
-      assert(e.getMessage.contains("NOT NULL constraint violated for column"))
+      val idPattern = "[A-Za-z_][A-Za-z0-9_]*"
+      val idsPattern = s"$idPattern(\\.$idPattern)*"
+      val checkPattern =
+        (s"CHECK constraint \\(\\(${idsPattern} IS NULL\\) OR \\(${idsPattern} IS NOT NULL\\)\\)" +
+          " violated by row with values").r
+      assert(
+        e.getMessage.contains("NOT NULL constraint violated for column") ||
+          checkPattern.findFirstIn(e.getMessage).nonEmpty
+      )
     }
 
   test("ALTER TABLE RENAME TO") {
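Because some code paths still raise the legacy NOT NULL message, the assertion now accepts either form; `checkPattern` matches the new CHECK wording for any dotted column path. A quick illustrative sanity check (the sample message is assembled by hand to mirror the new format, not captured from a real run):

```scala
// `top` and `top.key` both match idsPattern, and the surrounding literal
// text mirrors the CHECK violation message format asserted above.
val sample =
  "CHECK constraint ((top IS NULL) OR (top.key IS NOT NULL)) violated by row with values"
assert(checkPattern.findFirstIn(sample).nonEmpty)
```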
@@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.{CheckConstraintsTableFeature, DeltaLog, Delta
 import org.apache.spark.sql.delta.actions.{Metadata, TableFeatureProtocolUtils}
 import org.apache.spark.sql.delta.catalog.DeltaTableV2
 import org.apache.spark.sql.delta.constraints.{Constraint, Constraints, Invariants}
-import org.apache.spark.sql.delta.constraints.Constraints.NotNull
-import org.apache.spark.sql.delta.constraints.Invariants.PersistedExpression
+import org.apache.spark.sql.delta.constraints.Constraints.{Check, NotNull}
+import org.apache.spark.sql.delta.constraints.Invariants.{ArbitraryExpression, PersistedExpression}
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
 import org.apache.spark.sql.delta.test.DeltaSQLTestUtils
@@ -192,17 +192,15 @@ class InvariantEnforcementSuite extends QueryTest
       .add("key", StringType, nullable = false)
       .add("value", IntegerType))
     testBatchWriteRejection(
-      NotNull(Seq("key")),
+      Check("", ArbitraryExpression(spark, "top is null or top.key is not null").expression),
       schema,
       spark.createDataFrame(Seq(Row(Row("a", 1)), Row(Row(null, 2))).asJava, schema.asNullable),
       "top.key"
     )
-    testBatchWriteRejection(
-      NotNull(Seq("key")),
-      schema,
-      spark.createDataFrame(Seq(Row(Row("a", 1)), Row(null)).asJava, schema.asNullable),
-      "top.key"
-    )
+    tableWithSchema(schema) { path =>
+      spark.createDataFrame(Seq(Row(Row("a", 1)), Row(null)).asJava, schema.asNullable)
+        .write.mode("append").format("delta").save(path)
+    }
   }
 
   testQuietly("reject non-nullable array column") {
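The behavioral shift this suite pins down: the first write (a null `key` inside a present `top` struct) is still rejected, now via the CHECK constraint, while the second (the whole nullable `top` struct is null) is no longer rejected at all, which is why the second `testBatchWriteRejection` becomes a plain successful append inside `tableWithSchema`. Condensed sketch of the now-accepted write (mirroring the test body; `path`, `schema`, and the `Row`/`asJava` imports come from the surrounding suite):

```scala
// Row(null): the entire nullable struct `top` is null, so neither a NotNull
// constraint nor the new CHECK fires; the append succeeds with this patch.
spark.createDataFrame(Seq(Row(Row("a", 1)), Row(null)).asJava, schema.asNullable)
  .write.mode("append").format("delta").save(path)
```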