Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Nov 13, 2024
1 parent 44f830e commit 802ec46
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 38 deletions.
37 changes: 17 additions & 20 deletions python/pyspark/sql/tests/test_subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,23 @@ def test_correlated_scalar_subquery(self):
),
)

with self.subTest("without .outer()"):
assertDataFrameEqual(
self.spark.table("l").select(
"a",
(
self.spark.table("r")
.where(sf.col("b") == sf.col("a").outer())
.select(sf.sum("d"))
.scalar()
.alias("sum_d")
),
),
self.spark.sql(
"""select a, (select sum(d) from r where b = l.a) sum_d from l"""
),
)

with self.subTest("in select (null safe)"):
assertDataFrameEqual(
self.spark.table("l").select(
Expand Down Expand Up @@ -470,26 +487,6 @@ def test_scalar_subquery_with_outer_reference_errors(self):
fragment="outer",
)

with self.subTest("missing `outer()` for another outer"):
with self.assertRaises(AnalysisException) as pe:
self.spark.table("l").select(
"a",
(
self.spark.table("r")
.where(sf.col("b") == sf.col("a").outer())
.select(sf.sum("d"))
.scalar()
),
).collect()

self.check_error(
exception=pe.exception,
errorClass="UNRESOLVED_COLUMN.WITH_SUGGESTION",
messageParameters={"objectName": "`b`", "proposal": "`c`, `d`"},
query_context_type=QueryContextType.DataFrame,
fragment="col",
)


class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import org.apache.spark.internal.LogKeys.EXTENDED_EXPLAIN_GENERATOR
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, Row, SparkSession}
import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{LazyOuterReference, UnsupportedOperationChecker}
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
import org.apache.spark.sql.catalyst.expressions.LazyAnalysisExpression
import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command, CommandResult, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, ReturnAnswer, Union}
Expand Down Expand Up @@ -69,8 +70,9 @@ class QueryExecution(
protected def planner = sparkSession.sessionState.planner

lazy val isLazyAnalysis: Boolean = {
// Only check the main query as we can resolve LazyOuterReference inside subquery expressions.
logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyOuterReference])))
// Only check the main query as we can resolve LazyAnalysisExpression inside subquery
// expressions.
logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyAnalysisExpression])))
}

def assertAnalyzed(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,16 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
)
}

test("correlated scalar subquery in select (without .outer())") {
checkAnswer(
spark.table("l").select(
$"a",
spark.table("r").where($"b" === $"a".outer()).select(sum($"d")).as("sum_d").scalar()
),
sql("select a, (select sum(d) from r where b = l.a) sum_d from l")
)
}

test("correlated scalar subquery in select (null safe)") {
checkAnswer(
spark.table("l").select(
Expand Down Expand Up @@ -363,20 +373,5 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
queryContext =
Array(ExpectedContext(fragment = "outer", callSitePattern = getCurrentClassCallSitePattern))
)

// Missing `outer()` for another outer
val exception3 = intercept[AnalysisException] {
spark.table("l").select(
$"a",
spark.table("r").where($"b" === $"a".outer()).select(sum($"d")).scalar()
).collect()
}
checkError(
exception3,
condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
parameters = Map("objectName" -> "`b`", "proposal" -> "`c`, `d`"),
queryContext =
Array(ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern))
)
}
}

0 comments on commit 802ec46

Please sign in to comment.