From 44f830eb7c8e0b8134647f55aaea9a8e4f88b4c2 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 12 Nov 2024 13:27:35 +0800
Subject: [PATCH 1/2] Simplify the resolution of LazyOuterReference

---
 .../sql/catalyst/analysis/Analyzer.scala      | 24 +++-------
 .../sql/catalyst/analysis/CheckAnalysis.scala |  2 +-
 .../analysis/ColumnResolutionHelper.scala     | 48 ++++++++++---------
 .../sql/catalyst/analysis/unresolved.scala    | 21 ++------
 .../sql/catalyst/expressions/subquery.scala   | 15 +-----
 .../sql/catalyst/optimizer/Optimizer.scala    |  2 +-
 .../sql/catalyst/optimizer/expressions.scala  |  2 +-
 .../sql/catalyst/optimizer/subquery.scala     | 15 +++---
 .../scala/org/apache/spark/sql/Dataset.scala  |  8 +---
 .../spark/sql/execution/QueryExecution.scala  |  8 ++--
 .../adaptive/PlanAdaptiveSubqueries.scala     |  2 +-
 11 files changed, 56 insertions(+), 91 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d1d04d4117263..c30f53304ab6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2482,23 +2482,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    * can resolve outer references.
    *
    * Outer references of the subquery are updated as children of Subquery expression.
-   *
-   * If hasExplicitOuterRefs is true, the subquery should have an explicit outer reference,
-   * instead of common `UnresolvedAttribute`s. In this case, tries to resolve inner and outer
-   * references separately.
    */
   private def resolveSubQuery(
       e: SubqueryExpression,
-      outer: LogicalPlan,
-      hasExplicitOuterRefs: Boolean = false)(
+      outer: LogicalPlan)(
       f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
-    val newSubqueryPlan = if (hasExplicitOuterRefs) {
-      executeSameContext(e.plan).transformAllExpressionsWithPruning(
-        _.containsPattern(UNRESOLVED_OUTER_REFERENCE)) {
-        case u: UnresolvedOuterReference =>
-          resolveOuterReference(u.nameParts, outer).getOrElse(u)
-      }
-    } else AnalysisContext.withOuterPlan(outer) {
+    val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) {
       executeSameContext(e.plan)
     }
 
@@ -2523,11 +2512,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    */
   private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
     plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
-      case s @ ScalarSubquery(sub, _, exprId, _, _, _, _, hasExplicitOuterRefs)
-          if !sub.resolved =>
-        resolveSubQuery(s, outer, hasExplicitOuterRefs)(ScalarSubquery(_, _, exprId))
-      case e @ Exists(sub, _, exprId, _, _, hasExplicitOuterRefs) if !sub.resolved =>
-        resolveSubQuery(e, outer, hasExplicitOuterRefs)(Exists(_, _, exprId))
+      case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved =>
+        resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
+      case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
+        resolveSubQuery(e, outer)(Exists(_, _, exprId))
       case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
          if values.forall(_.resolved) && !l.resolved =>
         val expr = resolveSubQuery(l, outer)((plan, exprs) => {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c7e5fa9f2b6c6..75e8d8ad9923b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -1088,7 +1088,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
         checkOuterReference(plan, expr)
 
         expr match {
-          case ScalarSubquery(query, outerAttrs, _, _, _, _, _, _) =>
+          case ScalarSubquery(query, outerAttrs, _, _, _, _, _) =>
             // Scalar subquery must return one column as output.
             if (query.output.size != 1) {
               throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index e869cb281ce05..828d161115877 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -221,35 +221,39 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
     val outerPlan = AnalysisContext.get.outerPlan
     if (outerPlan.isEmpty) return e
 
-    e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
+    def resolve(nameParts: Seq[String]): Option[Expression] = try {
+      outerPlan.get match {
+        // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
+        // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
+        // push them down to Aggregate later. This is similar to what we do in `resolveColumns`.
+        case u @ UnresolvedHaving(_, agg: Aggregate) =>
+          agg.resolveChildren(nameParts, conf.resolver)
+            .orElse(u.resolveChildren(nameParts, conf.resolver))
+            .map(wrapOuterReference)
+        case other =>
+          other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference)
+      }
+    } catch {
+      case ae: AnalysisException =>
+        logDebug(ae.getMessage)
+        None
+    }
+
+    e.transformWithPruning(
+      _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN, LAZY_OUTER_REFERENCE)) {
       case u: UnresolvedAttribute =>
-        resolveOuterReference(u.nameParts, outerPlan.get).getOrElse(u)
+        resolve(u.nameParts).getOrElse(u)
+      case u: LazyOuterReference =>
+        // If we can't resolve this outer reference, replace it with `UnresolvedOuterReference` so
+        // that we can throw a better error to mention the candidate outer columns.
+        resolve(u.nameParts).getOrElse(UnresolvedOuterReference(u.nameParts))
       // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with
       // Aggregate but failed.
       case t: TempResolvedColumn if t.hasTried =>
-        resolveOuterReference(t.nameParts, outerPlan.get).getOrElse(t)
+        resolve(t.nameParts).getOrElse(t)
     }
   }
 
-  protected def resolveOuterReference(
-      nameParts: Seq[String], outerPlan: LogicalPlan): Option[Expression] = try {
-    outerPlan match {
-      // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
-      // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
-      // push them down to Aggregate later. This is similar to what we do in `resolveColumns`.
-      case u @ UnresolvedHaving(_, agg: Aggregate) =>
-        agg.resolveChildren(nameParts, conf.resolver)
-          .orElse(u.resolveChildren(nameParts, conf.resolver))
-          .map(wrapOuterReference)
-      case other =>
-        other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference)
-    }
-  } catch {
-    case ae: AnalysisException =>
-      logDebug(ae.getMessage)
-      None
-  }
-
   def lookupVariable(nameParts: Seq[String]): Option[VariableReference] = {
     // The temp variables live in `SYSTEM.SESSION`, and the name can be qualified or not.
     def maybeTempVariableName(nameParts: Seq[String]): Boolean = {
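The UnresolvedHaving branch above matters for correlated subqueries in a HAVING clause, where an outer reference can name a grouping or aggregate expression rather than a plain column of the outer relation. A minimal sketch of that shape, assuming a SparkSession named `spark` and the l(a, b) / r(c, d) fixtures this patch's tests use:

    // The EXISTS subquery's outer reference `r.c` is a grouping column, so it
    // only exists in the Aggregate below the UnresolvedHaving node and must be
    // resolved via `agg.resolveChildren` as in the rule above.
    spark.sql("""
      SELECT c, SUM(d) AS sum_d
      FROM r
      GROUP BY c
      HAVING EXISTS (SELECT 1 FROM l WHERE l.a = r.c)
    """).show()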
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 40994f42e71d6..157baac73c073 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -1004,42 +1004,29 @@ case class UnresolvedTranspose(
     copy(child = newChild)
 }
 
-case class UnresolvedOuterReference(
-    nameParts: Seq[String])
+case class UnresolvedOuterReference(nameParts: Seq[String])
   extends LeafExpression with NamedExpression with Unevaluable {
-
-  def name: String =
-    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
-
+  def name: String = nameParts.map(quoteIfNeeded).mkString(".")
   override def exprId: ExprId = throw new UnresolvedException("exprId")
   override def dataType: DataType = throw new UnresolvedException("dataType")
   override def nullable: Boolean = throw new UnresolvedException("nullable")
   override def qualifier: Seq[String] = throw new UnresolvedException("qualifier")
   override lazy val resolved = false
-
   override def toAttribute: Attribute = throw new UnresolvedException("toAttribute")
   override def newInstance(): UnresolvedOuterReference = this
-
   final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_OUTER_REFERENCE)
 }
 
-case class LazyOuterReference(
-    nameParts: Seq[String])
+case class LazyOuterReference(nameParts: Seq[String])
   extends LeafExpression with NamedExpression with Unevaluable with LazyAnalysisExpression {
-
-  def name: String =
-    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
-
+  def name: String = nameParts.map(quoteIfNeeded).mkString(".")
   override def exprId: ExprId = throw new UnresolvedException("exprId")
   override def dataType: DataType = throw new UnresolvedException("dataType")
   override def nullable: Boolean = throw new UnresolvedException("nullable")
   override def qualifier: Seq[String] = throw new UnresolvedException("qualifier")
-
   override def toAttribute: Attribute = throw new UnresolvedException("toAttribute")
   override def newInstance(): NamedExpression = LazyOuterReference(nameParts)
-
   override def nodePatternsInternal(): Seq[TreePattern] = Seq(LAZY_OUTER_REFERENCE)
-
   override def prettyName: String = "outer"
 
   override def sql: String = s"$prettyName($name)"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index bd6f65b61468d..0c8253659dd56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -19,11 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.analysis.{LazyOuterReference, UnresolvedOuterReference}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -374,13 +372,6 @@ object SubExprUtils extends PredicateHelper {
     val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs
     nonEquivalentGroupByExprs
   }
-
-  def removeLazyOuterReferences(logicalPlan: LogicalPlan): LogicalPlan = {
-    logicalPlan.transformAllExpressionsWithPruning(
-      _.containsPattern(TreePattern.LAZY_OUTER_REFERENCE)) {
-      case or: LazyOuterReference => UnresolvedOuterReference(or.nameParts)
-    }
-  }
 }
 
 /**
@@ -407,8 +398,7 @@ case class ScalarSubquery(
     joinCond: Seq[Expression] = Seq.empty,
     hint: Option[HintInfo] = None,
     mayHaveCountBug: Option[Boolean] = None,
-    needSingleJoin: Option[Boolean] = None,
-    hasExplicitOuterRefs: Boolean = false)
+    needSingleJoin: Option[Boolean] = None)
   extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
   override def dataType: DataType = {
     if (!plan.schema.fields.nonEmpty) {
@@ -577,8 +567,7 @@ case class Exists(
     outerAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     joinCond: Seq[Expression] = Seq.empty,
-    hint: Option[HintInfo] = None,
-    hasExplicitOuterRefs: Boolean = false)
+    hint: Option[HintInfo] = None)
   extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint)
   with Predicate with Unevaluable {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 90d9bd5d5d88e..9a2aa82c25d51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -346,7 +346,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       case d: DynamicPruningSubquery => d
       case s @ ScalarSubquery(
           PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child, _)),
-          _, _, _, _, mayHaveCountBug, _, _)
+          _, _, _, _, mayHaveCountBug, _)
         if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
           mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
         // This is a subquery with an aggregate that may suffer from a COUNT bug.
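This guard and the ConstantFolding change below both key off mayHaveCountBug. The classic shape of the bug, sketched against the same assumed l(a, b) / r(c, d) fixtures: a correlated COUNT must produce 0, not NULL, for outer rows with no match, so the subquery's Aggregate cannot be folded to a constant before decorrelation compensates for that:

    // For an `l.a` with no match in `r`, the correct answer is cnt = 0, not NULL.
    spark.sql("""
      SELECT a, (SELECT COUNT(*) FROM r WHERE r.c = l.a) AS cnt
      FROM l
    """).show()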
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 3eb7eb6e6b2e8..c9f0cfb6b4668 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
       }
 
       // Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug.
-      case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _, _)
+      case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _)
        if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
          mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
         s
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 8c82769dbf4a3..5a4e9f37c3951 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -131,12 +131,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
       // Filter the plan by applying left semi and left anti joins.
       withSubquery.foldLeft(newFilter) {
-        case (p, Exists(sub, _, _, conditions, subHint, _)) =>
+        case (p, Exists(sub, _, _, conditions, subHint)) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
           val join = buildJoin(outerPlan,
             rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond), LeftSemi, joinCond, subHint)
           Project(p.output, join)
-        case (p, Not(Exists(sub, _, _, conditions, subHint, _))) =>
+        case (p, Not(Exists(sub, _, _, conditions, subHint))) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
           val join = buildJoin(outerPlan,
             rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond), LeftAnti, joinCond, subHint)
@@ -319,7 +319,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     val introducedAttrs = ArrayBuffer.empty[Attribute]
     val newExprs = exprs.map { e =>
       e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
-        case Exists(sub, _, _, conditions, subHint, _) =>
+        case Exists(sub, _, _, conditions, subHint) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           val existenceJoin = ExistenceJoin(exists)
           val newCondition = conditions.reduceLeftOption(And)
@@ -507,7 +507,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
 
     plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
       case ScalarSubquery(sub, children, exprId, conditions, hint,
-          mayHaveCountBugOld, needSingleJoinOld, _)
+          mayHaveCountBugOld, needSingleJoinOld)
         if children.nonEmpty =>
 
         def mayHaveCountBugAgg(a: Aggregate): Boolean = {
@@ -560,7 +560,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
         }
         ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
           hint, Some(mayHaveCountBug), Some(needSingleJoin))
-      case Exists(sub, children, exprId, conditions, hint, _) if children.nonEmpty =>
+      case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) {
           decorrelate(sub, plan, handleCountBug = true)
         } else {
@@ -818,7 +818,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
     val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
     val newChild = subqueries.foldLeft(child) {
       case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug,
-          needSingleJoin, _)) =>
+          needSingleJoin)) =>
         val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
         val origOutput = query.output.head
         // The subquery appears on the right side of the join, hence add its hint to the right
@@ -1064,8 +1064,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
       case p: LogicalPlan => p.transformExpressionsUpWithPruning(
         _.containsPattern(SCALAR_SUBQUERY)) {
-        case s @ ScalarSubquery(
-            OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _, _)
+        case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _)
             if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
           assert(p.projectList.size == 1)
           stripOuterReferences(p.projectList).head
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 500a4c7c4d9bc..15acdd30445e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1005,16 +1005,12 @@ class Dataset[T] private[sql](
 
   /** @inheritdoc */
   def scalar(): Column = {
-    Column(ExpressionColumnNode(
-      ScalarSubqueryExpr(SubExprUtils.removeLazyOuterReferences(logicalPlan),
-        hasExplicitOuterRefs = true)))
+    Column(ExpressionColumnNode(ScalarSubqueryExpr(logicalPlan)))
   }
 
   /** @inheritdoc */
   def exists(): Column = {
-    Column(ExpressionColumnNode(
-      Exists(SubExprUtils.removeLazyOuterReferences(logicalPlan),
-        hasExplicitOuterRefs = true)))
+    Column(ExpressionColumnNode(Exists(logicalPlan)))
   }
 
   /** @inheritdoc */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 490184c93620a..d54f58a039f2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -31,12 +31,11 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
+import org.apache.spark.sql.catalyst.analysis.{LazyOuterReference, UnsupportedOperationChecker}
 import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
 import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
-import org.apache.spark.sql.catalyst.trees.TreePattern.LAZY_ANALYSIS_EXPRESSION
 import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
@@ -69,7 +68,10 @@ class QueryExecution(
   // TODO: Move the planner an optimizer into here from SessionState.
   protected def planner = sparkSession.sessionState.planner
 
-  lazy val isLazyAnalysis: Boolean = logical.containsAnyPattern(LAZY_ANALYSIS_EXPRESSION)
+  lazy val isLazyAnalysis: Boolean = {
+    // Only check the main query as we can resolve LazyOuterReference inside subquery expressions.
+    logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyOuterReference])))
+  }
 
   def assertAnalyzed(): Unit = {
     try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index 35a815d83922d..5f2638655c37c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -30,7 +30,7 @@ case class PlanAdaptiveSubqueries(
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressionsWithPruning(
       _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
-      case expressions.ScalarSubquery(_, _, exprId, _, _, _, _, _) =>
+      case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) =>
         val subquery = SubqueryExec.createForScalarSubquery(
           s"subquery#${exprId.id}", subqueryMap(exprId.id))
         execution.ScalarSubquery(subquery, exprId)
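To see what the isLazyAnalysis check distinguishes, here is a rough sketch, assuming a SparkSession named `spark` with the l(a, b) / r(c, d) fixtures: `$"a".outer()` builds a LazyOuterReference, so a plan carrying one in its main query cannot be eagerly analyzed, while the same plan nested as a subquery expression resolves normally because the analyzer then has an outer plan to resolve against.

    import org.apache.spark.sql.functions.sum
    import spark.implicits._

    // Standalone, `$"a".outer()` has no outer plan yet, so analysis stays lazy.
    val inner = spark.table("r").where($"c" === $"a".outer()).select(sum($"d"))

    // Nested via scalar(), the reference resolves against `l` in resolveSubQuery.
    spark.table("l").select($"a", inner.scalar().as("sum_d")).show()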
From 802ec466039d6333dfce7e634e3c1fdbc8c951cc Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 13 Nov 2024 18:47:02 +0800
Subject: [PATCH 2/2] fix

---
 python/pyspark/sql/tests/test_subquery.py    | 37 +++++++++----------
 .../spark/sql/execution/QueryExecution.scala |  8 ++--
 .../spark/sql/DataFrameSubquerySuite.scala   | 25 +++++--------
 3 files changed, 32 insertions(+), 38 deletions(-)

diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py
index f58ff6364aed7..43fde68a8e5ec 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -237,6 +237,23 @@ def test_correlated_scalar_subquery(self):
                 ),
             )
 
+        with self.subTest("without .outer()"):
+            assertDataFrameEqual(
+                self.spark.table("l").select(
+                    "a",
+                    (
+                        self.spark.table("r")
+                        .where(sf.col("b") == sf.col("a").outer())
+                        .select(sf.sum("d"))
+                        .scalar()
+                        .alias("sum_d")
+                    ),
+                ),
+                self.spark.sql(
+                    """select a, (select sum(d) from r where b = l.a) sum_d from l"""
+                ),
+            )
+
         with self.subTest("in select (null safe)"):
             assertDataFrameEqual(
                 self.spark.table("l").select(
@@ -470,26 +487,6 @@ def test_scalar_subquery_with_outer_reference_errors(self):
                 fragment="outer",
             )
 
-        with self.subTest("missing `outer()` for another outer"):
-            with self.assertRaises(AnalysisException) as pe:
-                self.spark.table("l").select(
-                    "a",
-                    (
-                        self.spark.table("r")
-                        .where(sf.col("b") == sf.col("a").outer())
-                        .select(sf.sum("d"))
-                        .scalar()
-                    ),
-                ).collect()
-
-            self.check_error(
-                exception=pe.exception,
-                errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
-                messageParameters={"objectName": "`b`", "proposal": "`c`, `d`"},
-                query_context_type=QueryContextType.DataFrame,
-                fragment="col",
-            )
-
 
 class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index d54f58a039f2c..8e003e09032dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -31,7 +31,8 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{LazyOuterReference, UnsupportedOperationChecker}
+import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
+import org.apache.spark.sql.catalyst.expressions.LazyAnalysisExpression
 import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
@@ -69,8 +70,9 @@ class QueryExecution(
   protected def planner = sparkSession.sessionState.planner
 
   lazy val isLazyAnalysis: Boolean = {
-    // Only check the main query as we can resolve LazyOuterReference inside subquery expressions.
-    logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyOuterReference])))
+    // Only check the main query as we can resolve LazyAnalysisExpression inside subquery
+    // expressions.
+    logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyAnalysisExpression])))
   }
 
   def assertAnalyzed(): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index 5a065d7e73b1c..f39926dec40fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -256,6 +256,16 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
     )
   }
 
+  test("correlated scalar subquery in select (without .outer())") {
+    checkAnswer(
+      spark.table("l").select(
+        $"a",
+        spark.table("r").where($"b" === $"a".outer()).select(sum($"d")).as("sum_d").scalar()
+      ),
+      sql("select a, (select sum(d) from r where b = l.a) sum_d from l")
+    )
+  }
+
   test("correlated scalar subquery in select (null safe)") {
     checkAnswer(
       spark.table("l").select(
@@ -363,20 +373,5 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
       queryContext =
         Array(ExpectedContext(fragment = "outer", callSitePattern = getCurrentClassCallSitePattern))
     )
-
-    // Missing `outer()` for another outer
-    val exception3 = intercept[AnalysisException] {
-      spark.table("l").select(
-        $"a",
-        spark.table("r").where($"b" === $"a".outer()).select(sum($"d")).scalar()
-      ).collect()
-    }
-    checkError(
-      exception3,
-      condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
-      parameters = Map("objectName" -> "`b`", "proposal" -> "`c`, `d`"),
-      queryContext =
-        Array(ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern))
-    )
   }
 }
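Taken together, the two commits make the DataFrame subquery API behave like its SQL counterpart: only genuinely outer columns need an explicit outer(), and any name resolvable against some plan in scope, inner or outer, resolves in place. A rough usage sketch of the end state, assuming a local SparkSession and illustrative fixtures shaped like the test tables:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.sum

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    Seq((1, 2.0), (2, 1.0)).toDF("a", "b").createOrReplaceTempView("l")
    Seq((2, 3.0), (3, 5.0)).toDF("c", "d").createOrReplaceTempView("r")

    // `b` is not a column of `r`, so it falls back to the outer plan `l`
    // implicitly, exactly like `where b = l.a` in the equivalent SQL.
    spark.table("l").select(
      $"a",
      spark.table("r").where($"b" === $"a".outer()).select(sum($"d")).scalar().as("sum_d")
    ).show()

    // The same resolution rules apply to EXISTS subqueries.
    spark.table("l").where(
      spark.table("r").where($"c" === $"a".outer()).exists()
    ).show()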