[SPARK-50130][SQL][FOLLOWUP] Simplify the resolution of LazyOuterReference #48820

Status: Open · wants to merge 2 commits into base: master
Changes from all commits
python/pyspark/sql/tests/test_subquery.py — 37 changes: 17 additions & 20 deletions
@@ -237,6 +237,23 @@ def test_correlated_scalar_subquery(self):
                 ),
             )
 
+        with self.subTest("without .outer()"):
+            assertDataFrameEqual(
+                self.spark.table("l").select(
+                    "a",
+                    (
+                        self.spark.table("r")
+                        .where(sf.col("b") == sf.col("a").outer())
+                        .select(sf.sum("d"))
+                        .scalar()
+                        .alias("sum_d")
+                    ),
+                ),
+                self.spark.sql(
+                    """select a, (select sum(d) from r where b = l.a) sum_d from l"""
+                ),
+            )
+
         with self.subTest("in select (null safe)"):
             assertDataFrameEqual(
                 self.spark.table("l").select(
@@ -470,26 +487,6 @@ def test_scalar_subquery_with_outer_reference_errors(self):
                 fragment="outer",
             )
 
-        with self.subTest("missing `outer()` for another outer"):
-            with self.assertRaises(AnalysisException) as pe:
-                self.spark.table("l").select(
-                    "a",
-                    (
-                        self.spark.table("r")
-                        .where(sf.col("b") == sf.col("a").outer())
-                        .select(sf.sum("d"))
-                        .scalar()
-                    ),
-                ).collect()
-
-            self.check_error(
-                exception=pe.exception,
-                errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
-                messageParameters={"objectName": "`b`", "proposal": "`c`, `d`"},
-                query_context_type=QueryContextType.DataFrame,
-                fragment="col",
-            )
-
 
 class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
     pass
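
For reference, a rough Scala equivalent of the new test case, sketched against the Dataset API this PR touches (`scalar()` appears in the Dataset.scala diff below; `Column.outer()` is assumed from the parent SPARK-50130 change, so treat this as a sketch rather than verified API). Since `r` has no column `b`, the bare reference now falls back to the outer query, just as it would in SQL:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, sum}

object ImplicitOuterRefSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    Seq((1, 2.0), (2, 3.0)).toDF("a", "b").createOrReplaceTempView("l")
    Seq((1, 10.0), (2, 20.0)).toDF("c", "d").createOrReplaceTempView("r")

    // Equivalent to: SELECT a, (SELECT sum(d) FROM r WHERE b = l.a) AS sum_d FROM l
    // `b` is not a column of r, so the analyzer resolves it against the outer
    // query (l.b) without an explicit .outer() call; `a` keeps the explicit
    // marker here, mirroring the Python test above.
    val sumD = spark.table("r")
      .where(col("b") === col("a").outer())
      .select(sum("d"))
      .scalar()
      .alias("sum_d")

    spark.table("l").select(col("a"), sumD).show()
  }
}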

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

@@ -2482,23 +2482,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    * can resolve outer references.
    *
    * Outer references of the subquery are updated as children of Subquery expression.
-   *
-   * If hasExplicitOuterRefs is true, the subquery should have an explicit outer reference,
-   * instead of common `UnresolvedAttribute`s. In this case, tries to resolve inner and outer
-   * references separately.
    */
   private def resolveSubQuery(
       e: SubqueryExpression,
-      outer: LogicalPlan,
-      hasExplicitOuterRefs: Boolean = false)(
+      outer: LogicalPlan)(
       f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
-    val newSubqueryPlan = if (hasExplicitOuterRefs) {
-      executeSameContext(e.plan).transformAllExpressionsWithPruning(
-        _.containsPattern(UNRESOLVED_OUTER_REFERENCE)) {
-        case u: UnresolvedOuterReference =>
-          resolveOuterReference(u.nameParts, outer).getOrElse(u)
-      }
-    } else AnalysisContext.withOuterPlan(outer) {
+    val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) {
       executeSameContext(e.plan)
     }
@@ -2523,11 +2512,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    */
   private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
     plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
-      case s @ ScalarSubquery(sub, _, exprId, _, _, _, _, hasExplicitOuterRefs)
-          if !sub.resolved =>
-        resolveSubQuery(s, outer, hasExplicitOuterRefs)(ScalarSubquery(_, _, exprId))
-      case e @ Exists(sub, _, exprId, _, _, hasExplicitOuterRefs) if !sub.resolved =>
-        resolveSubQuery(e, outer, hasExplicitOuterRefs)(Exists(_, _, exprId))
+      case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved =>
+        resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
+      case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
+        resolveSubQuery(e, outer)(Exists(_, _, exprId))
       case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
           if values.forall(_.resolved) && !l.resolved =>
         val expr = resolveSubQuery(l, outer)((plan, exprs) => {
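
The simplification above drops the special `hasExplicitOuterRefs` code path: the subquery plan is now always analyzed inside a context that carries the outer plan, and the column-resolution helpers consult it on demand. A toy model of that context-scoping pattern (illustrative names and string "plans", not Spark's real classes):

// Minimal sketch of context-scoped outer-plan resolution, loosely modeling
// AnalysisContext.withOuterPlan. The "plans" are plain strings for brevity.
object OuterPlanContext {
  private val current = new ThreadLocal[Option[String]] {
    override def initialValue(): Option[String] = None
  }

  // Make `outer` visible to any resolution logic invoked while running `body`.
  def withOuterPlan[T](outer: String)(body: => T): T = {
    val saved = current.get()
    current.set(Some(outer))
    try body finally current.set(saved)
  }

  def outerPlan: Option[String] = current.get()
}

object OuterPlanContextDemo extends App {
  def resolve(column: String): String = OuterPlanContext.outerPlan match {
    case Some(outer) => s"$column: inner scope first, then fall back to $outer"
    case None        => s"$column: inner scope only"
  }

  println(resolve("b"))                                          // no outer plan available
  println(OuterPlanContext.withOuterPlan("l") { resolve("b") })  // outer plan "l" visible
}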

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

@@ -1088,7 +1088,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
       checkOuterReference(plan, expr)
 
       expr match {
-        case ScalarSubquery(query, outerAttrs, _, _, _, _, _, _) =>
+        case ScalarSubquery(query, outerAttrs, _, _, _, _, _) =>
           // Scalar subquery must return one column as output.
           if (query.output.size != 1) {
             throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala

@@ -221,35 +221,39 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
     val outerPlan = AnalysisContext.get.outerPlan
     if (outerPlan.isEmpty) return e
 
-    e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
+    def resolve(nameParts: Seq[String]): Option[Expression] = try {
+      outerPlan.get match {
+        // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
+        // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
+        // push them down to Aggregate later. This is similar to what we do in `resolveColumns`.
+        case u @ UnresolvedHaving(_, agg: Aggregate) =>
+          agg.resolveChildren(nameParts, conf.resolver)
+            .orElse(u.resolveChildren(nameParts, conf.resolver))
+            .map(wrapOuterReference)
+        case other =>
+          other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference)
+      }
+    } catch {
+      case ae: AnalysisException =>
+        logDebug(ae.getMessage)
+        None
+    }
+
+    e.transformWithPruning(
+      _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN, LAZY_OUTER_REFERENCE)) {
       case u: UnresolvedAttribute =>
-        resolveOuterReference(u.nameParts, outerPlan.get).getOrElse(u)
+        resolve(u.nameParts).getOrElse(u)
+      case u: LazyOuterReference =>
+        // If we can't resolve this outer reference, replace it with `UnresolvedOuterReference` so
+        // that we can throw a better error to mention the candidate outer columns.
+        resolve(u.nameParts).getOrElse(UnresolvedOuterReference(u.nameParts))
       // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with
       // Aggregate but failed.
       case t: TempResolvedColumn if t.hasTried =>
-        resolveOuterReference(t.nameParts, outerPlan.get).getOrElse(t)
+        resolve(t.nameParts).getOrElse(t)
     }
   }
 
-  protected def resolveOuterReference(
-      nameParts: Seq[String], outerPlan: LogicalPlan): Option[Expression] = try {
-    outerPlan match {
-      // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
-      // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
-      // push them down to Aggregate later. This is similar to what we do in `resolveColumns`.
-      case u @ UnresolvedHaving(_, agg: Aggregate) =>
-        agg.resolveChildren(nameParts, conf.resolver)
-          .orElse(u.resolveChildren(nameParts, conf.resolver))
-          .map(wrapOuterReference)
-      case other =>
-        other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference)
-    }
-  } catch {
-    case ae: AnalysisException =>
-      logDebug(ae.getMessage)
-      None
-  }
-
   def lookupVariable(nameParts: Seq[String]): Option[VariableReference] = {
     // The temp variables live in `SYSTEM.SESSION`, and the name can be qualified or not.
     def maybeTempVariableName(nameParts: Seq[String]): Boolean = {

Author's review comment on the `case u: LazyOuterReference =>` branch: "This is the new change. Other changes are just reverting back the previous changes."
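
As the author's note above says, the `LazyOuterReference` branch is the only genuinely new logic: on failure it downgrades to `UnresolvedOuterReference`, which lets the later analysis check report candidate outer columns instead of a generic unresolved-column error. A compact, self-contained sketch of that fallback chain (stand-in types, not Spark's):

// Illustrative stand-ins for the analyzer's expression kinds.
sealed trait Expr
final case class UnresolvedAttribute(nameParts: Seq[String]) extends Expr
final case class LazyOuterReference(nameParts: Seq[String]) extends Expr
final case class UnresolvedOuterReference(nameParts: Seq[String]) extends Expr
final case class OuterReference(name: String) extends Expr

object ResolveFallbackDemo extends App {
  // Pretend only the outer table's columns `a` and `b` are resolvable.
  val outerColumns = Set("a", "b")

  def resolve(nameParts: Seq[String]): Option[Expr] =
    nameParts.lastOption.filter(outerColumns).map(OuterReference.apply)

  def resolveOuter(e: Expr): Expr = e match {
    // Plain attribute: leave it unresolved if the outer plan cannot help.
    case u @ UnresolvedAttribute(np) => resolve(np).getOrElse(u)
    // Lazy outer reference: downgrade on failure so the later check can
    // mention candidate *outer* columns in its error message.
    case LazyOuterReference(np) =>
      resolve(np).getOrElse(UnresolvedOuterReference(np))
    case other => other
  }

  println(resolveOuter(LazyOuterReference(Seq("a"))))  // OuterReference(a)
  println(resolveOuter(LazyOuterReference(Seq("x"))))  // UnresolvedOuterReference(List(x))
}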

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

@@ -1004,42 +1004,29 @@ case class UnresolvedTranspose(
     copy(child = newChild)
 }
 
-case class UnresolvedOuterReference(
-    nameParts: Seq[String])
+case class UnresolvedOuterReference(nameParts: Seq[String])
   extends LeafExpression with NamedExpression with Unevaluable {
 
-  def name: String =
-    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
-
+  def name: String = nameParts.map(quoteIfNeeded).mkString(".")
 
   override def exprId: ExprId = throw new UnresolvedException("exprId")
   override def dataType: DataType = throw new UnresolvedException("dataType")
   override def nullable: Boolean = throw new UnresolvedException("nullable")
   override def qualifier: Seq[String] = throw new UnresolvedException("qualifier")
   override lazy val resolved = false
 
   override def toAttribute: Attribute = throw new UnresolvedException("toAttribute")
   override def newInstance(): UnresolvedOuterReference = this
 
   final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_OUTER_REFERENCE)
 }
 
-case class LazyOuterReference(
-    nameParts: Seq[String])
+case class LazyOuterReference(nameParts: Seq[String])
  extends LeafExpression with NamedExpression with Unevaluable with LazyAnalysisExpression {
 
-  def name: String =
-    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
-
+  def name: String = nameParts.map(quoteIfNeeded).mkString(".")
 
   override def exprId: ExprId = throw new UnresolvedException("exprId")
   override def dataType: DataType = throw new UnresolvedException("dataType")
   override def nullable: Boolean = throw new UnresolvedException("nullable")
   override def qualifier: Seq[String] = throw new UnresolvedException("qualifier")
 
   override def toAttribute: Attribute = throw new UnresolvedException("toAttribute")
   override def newInstance(): NamedExpression = LazyOuterReference(nameParts)
 
   override def nodePatternsInternal(): Seq[TreePattern] = Seq(LAZY_OUTER_REFERENCE)
 
   override def prettyName: String = "outer"
   override def sql: String = s"$prettyName($name)"
 }

A reviewer (Member) left a 👍 on each of the two `quoteIfNeeded` simplifications.
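
Both approvals concern swapping the hand-rolled quoting (which only handled dots) for the shared `quoteIfNeeded` helper from `org.apache.spark.sql.catalyst.util`. A self-contained approximation of the behavior difference (the helper body here is reconstructed from memory, not authoritative):

object QuotingSketch {
  // Approximation of quoteIfNeeded: leave simple identifiers alone, otherwise
  // wrap in backticks and escape embedded backticks by doubling them.
  def quoteIfNeeded(part: String): String =
    if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) part
    else s"`${part.replace("`", "``")}`"

  // The replaced code only quoted name parts containing a dot.
  def oldQuote(part: String): String =
    if (part.contains(".")) s"`$part`" else part

  def main(args: Array[String]): Unit = {
    println(Seq("db", "my.col").map(oldQuote).mkString("."))       // db.`my.col`
    println(Seq("db", "my.col").map(quoteIfNeeded).mkString("."))  // db.`my.col`
    println(Seq("db", "a`b").map(oldQuote).mkString("."))          // db.a`b (ambiguous)
    println(Seq("db", "a`b").map(quoteIfNeeded).mkString("."))     // db.`a``b`
  }
}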

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

@@ -19,11 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.analysis.{LazyOuterReference, UnresolvedOuterReference}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf

@@ -374,13 +372,6 @@ object SubExprUtils extends PredicateHelper {
     val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs
     nonEquivalentGroupByExprs
   }
-
-  def removeLazyOuterReferences(logicalPlan: LogicalPlan): LogicalPlan = {
-    logicalPlan.transformAllExpressionsWithPruning(
-      _.containsPattern(TreePattern.LAZY_OUTER_REFERENCE)) {
-      case or: LazyOuterReference => UnresolvedOuterReference(or.nameParts)
-    }
-  }
 }
 
 /**

@@ -407,8 +398,7 @@ case class ScalarSubquery(
     joinCond: Seq[Expression] = Seq.empty,
     hint: Option[HintInfo] = None,
     mayHaveCountBug: Option[Boolean] = None,
-    needSingleJoin: Option[Boolean] = None,
-    hasExplicitOuterRefs: Boolean = false)
+    needSingleJoin: Option[Boolean] = None)
   extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
   override def dataType: DataType = {
     if (!plan.schema.fields.nonEmpty) {

@@ -577,8 +567,7 @@ case class Exists(
     outerAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     joinCond: Seq[Expression] = Seq.empty,
-    hint: Option[HintInfo] = None,
-    hasExplicitOuterRefs: Boolean = false)
+    hint: Option[HintInfo] = None)
   extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint)
   with Predicate
   with Unevaluable {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

@@ -346,7 +346,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       case d: DynamicPruningSubquery => d
       case s @ ScalarSubquery(
           PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child, _)),
-          _, _, _, _, mayHaveCountBug, _, _)
+          _, _, _, _, mayHaveCountBug, _)
         if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
           mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
         // This is a subquery with an aggregate that may suffer from a COUNT bug.
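
For context on the guarded case: a correlated scalar `count(*)` must return 0 for outer rows with no match, but a naive decorrelation into an outer join over a pre-aggregated subquery yields NULL instead; that discrepancy is the COUNT bug, and constant-folding the aggregate away would hide it from the rewrite that compensates. A small illustration (hypothetical toy tables, not from this PR's tests):

import org.apache.spark.sql.SparkSession

object CountBugDemo extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  import spark.implicits._

  Seq(1, 2, 3).toDF("a").createOrReplaceTempView("l")
  Seq(1, 2).toDF("c").createOrReplaceTempView("r")

  // For a = 3 there is no match in r: correct subquery semantics give cnt = 0,
  // while a naive outer-join decorrelation would produce NULL for that row.
  spark.sql(
    "SELECT a, (SELECT count(*) FROM r WHERE r.c = l.a) AS cnt FROM l").show()
}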

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

@@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
     }
 
     // Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug.
-    case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _, _)
+    case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _)
       if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) &&
         mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
       s

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

@@ -131,12 +131,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
     // Filter the plan by applying left semi and left anti joins.
     withSubquery.foldLeft(newFilter) {
-      case (p, Exists(sub, _, _, conditions, subHint, _)) =>
+      case (p, Exists(sub, _, _, conditions, subHint)) =>
        val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
        val join = buildJoin(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond),
          LeftSemi, joinCond, subHint)
        Project(p.output, join)
-      case (p, Not(Exists(sub, _, _, conditions, subHint, _))) =>
+      case (p, Not(Exists(sub, _, _, conditions, subHint))) =>
        val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
        val join = buildJoin(outerPlan, rewriteDomainJoinsIfPresent(outerPlan, sub, joinCond),
          LeftAnti, joinCond, subHint)

@@ -319,7 +319,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     val introducedAttrs = ArrayBuffer.empty[Attribute]
     val newExprs = exprs.map { e =>
       e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
-        case Exists(sub, _, _, conditions, subHint, _) =>
+        case Exists(sub, _, _, conditions, subHint) =>
          val exists = AttributeReference("exists", BooleanType, nullable = false)()
          val existenceJoin = ExistenceJoin(exists)
          val newCondition = conditions.reduceLeftOption(And)

@@ -507,7 +507,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
 
     plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
       case ScalarSubquery(sub, children, exprId, conditions, hint,
-          mayHaveCountBugOld, needSingleJoinOld, _)
+          mayHaveCountBugOld, needSingleJoinOld)
         if children.nonEmpty =>
 
        def mayHaveCountBugAgg(a: Aggregate): Boolean = {

@@ -560,7 +560,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       }
       ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
         hint, Some(mayHaveCountBug), Some(needSingleJoin))
-    case Exists(sub, children, exprId, conditions, hint, _) if children.nonEmpty =>
+    case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
      val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) {
        decorrelate(sub, plan, handleCountBug = true)
      } else {

@@ -818,7 +818,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
     val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
     val newChild = subqueries.foldLeft(child) {
       case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug,
-          needSingleJoin, _)) =>
+          needSingleJoin)) =>
        val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
        val origOutput = query.output.head
        // The subquery appears on the right side of the join, hence add its hint to the right

@@ -1064,8 +1064,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
 
     case p: LogicalPlan => p.transformExpressionsUpWithPruning(
       _.containsPattern(SCALAR_SUBQUERY)) {
-      case s @ ScalarSubquery(
-          OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _, _)
+      case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _)
        if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
        assert(p.projectList.size == 1)
        stripOuterReferences(p.projectList).head

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala — 8 changes: 2 additions & 6 deletions

@@ -1005,16 +1005,12 @@ class Dataset[T] private[sql](
 
   /** @inheritdoc */
   def scalar(): Column = {
-    Column(ExpressionColumnNode(
-      ScalarSubqueryExpr(SubExprUtils.removeLazyOuterReferences(logicalPlan),
-        hasExplicitOuterRefs = true)))
+    Column(ExpressionColumnNode(ScalarSubqueryExpr(logicalPlan)))
   }
 
   /** @inheritdoc */
   def exists(): Column = {
-    Column(ExpressionColumnNode(
-      Exists(SubExprUtils.removeLazyOuterReferences(logicalPlan),
-        hasExplicitOuterRefs = true)))
+    Column(ExpressionColumnNode(Exists(logicalPlan)))
   }
 
   /** @inheritdoc */
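
With the `removeLazyOuterReferences` rewrite gone, `scalar()` and `exists()` simply wrap the child plan and leave lazy outer references for the analyzer to resolve. Usage is unchanged; a hedged Scala example of `exists()` (again assuming `Column.outer()` from the parent SPARK-50130 change):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object ExistsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    Seq((1, 1.0), (2, 2.0)).toDF("a", "b").createOrReplaceTempView("l")
    Seq((2, 20.0)).toDF("c", "d").createOrReplaceTempView("r")

    // Equivalent to: SELECT * FROM l WHERE EXISTS (SELECT * FROM r WHERE c = l.a)
    val existsCol = spark.table("r")
      .where(col("c") === col("a").outer())
      .exists()

    spark.table("l").where(existsCol).show()  // keeps rows of l with a match in r
  }
}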

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

@@ -32,11 +32,11 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
+import org.apache.spark.sql.catalyst.expressions.LazyAnalysisExpression
 import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
 import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
-import org.apache.spark.sql.catalyst.trees.TreePattern.LAZY_ANALYSIS_EXPRESSION
 import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}

@@ -69,7 +69,11 @@ class QueryExecution(
   // TODO: Move the planner an optimizer into here from SessionState.
   protected def planner = sparkSession.sessionState.planner
 
-  lazy val isLazyAnalysis: Boolean = logical.containsAnyPattern(LAZY_ANALYSIS_EXPRESSION)
+  lazy val isLazyAnalysis: Boolean = {
+    // Only check the main query as we can resolve LazyAnalysisExpression inside subquery
+    // expressions.
+    logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyAnalysisExpression])))
+  }
 
   def assertAnalyzed(): Unit = {
     try {
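
The distinction in the new `isLazyAnalysis` matters because tree-pattern bits propagate upward from plans nested inside subquery expressions, so `containsAnyPattern` would flag the whole query even when the lazy reference sits inside a subquery and can be resolved there; the explicit traversal stops at the plan boundary. A toy model of that difference (illustrative types, not Spark's classes):

// Expression trees whose traversal deliberately does not cross into a nested plan.
sealed trait ExprNode { def kids: Seq[ExprNode] }
case object LazyRef extends ExprNode { val kids = Nil }
final case class Plain(kids: List[ExprNode]) extends ExprNode
// Holds a nested subquery plan; its expression-level children exclude that
// plan's own expressions (as SubqueryExpression's children do in Spark).
final case class SubqueryHolder(nestedPlanExprs: List[ExprNode]) extends ExprNode {
  val kids = Nil
}

object LazyCheckDemo extends App {
  def containsLazy(e: ExprNode): Boolean =
    (e == LazyRef) || e.kids.exists(containsLazy)

  val topLevelExprs: List[ExprNode] = List(Plain(List(SubqueryHolder(List(LazyRef)))))
  // The lazy reference lives only inside the nested plan, so the main query is
  // not flagged as lazy-analysis: it gets resolved while analyzing the subquery.
  println(topLevelExprs.exists(containsLazy))  // false
}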

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala

@@ -30,7 +30,7 @@ case class PlanAdaptiveSubqueries(
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressionsWithPruning(
       _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
-      case expressions.ScalarSubquery(_, _, exprId, _, _, _, _, _) =>
+      case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) =>
        val subquery = SubqueryExec.createForScalarSubquery(
          s"subquery#${exprId.id}", subqueryMap(exprId.id))
        execution.ScalarSubquery(subquery, exprId)