Skip to content

Commit

Permalink
[BugFix] fix agg in order by not being analyzed correctly (#30108)
Browse files Browse the repository at this point in the history
---------

Signed-off-by: packy92 <[email protected]>
  • Loading branch information
packy92 committed Sep 13, 2023
1 parent e321e85 commit 7308217
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public Boolean visitFunctionCall(FunctionCallExpr expr, Void context) {
arg.getFn() instanceof AggregateFunction, aggFunc);
return !aggFunc.isEmpty();
})) {
throw new SemanticException(PARSER_ERROR_MSG.unsupportedNestAgg("window function"), expr.getPos());
throw new SemanticException(PARSER_ERROR_MSG.unsupportedNestAgg("aggregation function"), expr.getPos());
}

if (expr.getChildren().stream().anyMatch(childExpr -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ public void analyzeIgnoreSlot(Expr expression, AnalyzeState analyzeState, Scope
bottomUpAnalyze(visitor, expression, scope);
}

/**
 * Analyzes {@code expression} bottom-up within {@code scope} without registering
 * resolved slots as column references.
 *
 * <p>The visitor used here overrides {@code handleResolvedField} with an empty body,
 * so the {@code analyzeState.addColumnReference(...)} call made by the default hook
 * is skipped. Only that hook is replaced; the visitor is still constructed with
 * {@code analyzeState}.
 *
 * @param expression   expression to analyze
 * @param analyzeState state handed to the visitor
 * @param scope        scope in which the expression is resolved
 */
public void analyzeWithoutUpdateState(Expr expression, AnalyzeState analyzeState, Scope scope) {
    bottomUpAnalyze(
            new Visitor(analyzeState, session) {
                @Override
                protected void handleResolvedField(SlotRef slot, ResolvedField resolvedField) {
                    // Deliberately a no-op: the resolved slot must not be recorded in analyzeState.
                }
            },
            expression,
            scope);
}

private boolean isArrayHighOrderFunction(Expr expr) {
if (expr instanceof FunctionCallExpr) {
if (((FunctionCallExpr) expr).getFnName().getFunction().equals(FunctionSet.ARRAY_MAP) ||
Expand Down Expand Up @@ -352,7 +362,7 @@ public Void visitExpression(Expr node, Scope scope) {
node.getPos());
}

private void handleResolvedField(SlotRef slot, ResolvedField resolvedField) {
protected void handleResolvedField(SlotRef slot, ResolvedField resolvedField) {
analyzeState.addColumnReference(slot, FieldId.from(resolvedField));
}

Expand Down Expand Up @@ -407,7 +417,6 @@ public Void visitSlot(SlotRef node, Scope scope) {
node.resetStructInfo();
}
}

handleResolvedField(node, resolvedField);
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ private List<OrderByElement> analyzeOrderBy(List<OrderByElement> orderByElements
// but should be parsed in sourceScope
analyzeExpression(expression, analyzeState, orderByScope.getParent());
} else {
ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(session);
expressionAnalyzer.analyzeWithoutUpdateState(expression, analyzeState, orderByScope);
List<Expr> aggregations = Lists.newArrayList();
expression.collectAll(e -> e.isAggregate(), aggregations);
aggregations.forEach(e -> analyzeExpression(e, analyzeState, orderByScope.getParent()));
analyzeExpression(expression, analyzeState, orderByScope);
}

Expand Down Expand Up @@ -423,7 +428,7 @@ private void analyzeGroupingOperations(AnalyzeState analyzeState, GroupByClause
}

private List<FunctionCallExpr> analyzeAggregations(AnalyzeState analyzeState, Scope sourceScope,
List<Expr> outputAndOrderByExpressions) {
List<Expr> outputAndOrderByExpressions) {
List<FunctionCallExpr> aggregations = Lists.newArrayList();
TreeNode.collect(outputAndOrderByExpressions, Expr.isAggregatePredicate()::apply, aggregations);
aggregations.forEach(e -> analyzeExpression(e, analyzeState, sourceScope));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ public Void visitExpression(Expr node, Scope scope) {
throw unsupportedException("not yet implemented: expression analyzer for " + node.getClass().getName());
}

private void handleResolvedField(SlotRef slot, ResolvedField resolvedField) {
protected void handleResolvedField(SlotRef slot, ResolvedField resolvedField) {
analyzeState.addColumnReference(slot, FieldId.from(resolvedField));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ public void testAggregate() {
analyzeFail("select sum(v1) from t0 order by sum(abs(max(v2) over ()))",
"Unsupported nest window function inside aggregation.");
analyzeFail("select sum(v1) from t0 order by sum(max(v2))",
"Unsupported nest window function inside aggregation.");
"Unsupported nest aggregation function inside aggregation.");
analyzeFail("select sum(v1) from t0 order by sum(abs(max(v2)))",
"Unsupported nest window function inside aggregation.");
"Unsupported nest aggregation function inside aggregation.");
analyzeFail("select sum(max(v2)) from t0",
"Unsupported nest window function inside aggregation.");
"Unsupported nest aggregation function inside aggregation.");
analyzeFail("select sum(1 + max(v2)) from t0",
"Unsupported nest window function inside aggregation.");
"Unsupported nest aggregation function inside aggregation.");
analyzeFail("select min(v1) col from t0 order by min(col) + 1, min(col)",
"Column 'col' cannot be resolved");

analyzeFail("select v1 from t0 group by v1,cast(v2 as int) having cast(v2 as boolean)",
"must be an aggregate expression or appear in GROUP BY clause");
Expand Down
15 changes: 15 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2512,4 +2512,19 @@ public void testApproxTopK() throws Exception {
Assert.assertTrue(actualMessage.contains(expectedMessage));
}
}

/**
 * Plans a grouped query whose ORDER BY combines aggregate calls with a name ({@code t1f})
 * that is also used as a select-list alias, and checks the resulting fragment plan.
 *
 * <p>The plan is expected to contain an aggregate node that outputs {@code count(t1e)} and
 * {@code min(t1f)} grouped by {@code t1a, t1b}, and an ORDER BY over the {@code round},
 * {@code min} and {@code abs} slots.
 *
 * @throws Exception if fragment plan generation fails
 */
@Test
public void testOrderByWithAgg() throws Exception {
    String sql = "select round(count(t1e) * 100.0 / min(t1f), 4) as potential_customer_rate, " +
            "min(t1f) as t1f, min(t1f) as t1f from test_all_type_not_null " +
            "group by t1a, t1b " +
            "order by round(count(t1e) * 100.0 / min(t1f) / min(t1f), 4), min(t1f), abs(t1f)";

    // No debug printing of the plan here: the assertion below is the only check this test needs.
    String plan = getFragmentPlan(sql);
    assertCContains(plan, "1:AGGREGATE (update finalize)\n" +
                    " | output: count(5: t1e), min(6: t1f)\n" +
                    " | group by: 1: t1a, 2: t1b",
            "order by: <slot 14> 14: round ASC, <slot 12> 12: min ASC, <slot 15> 15: abs ASC");
}
}
76 changes: 76 additions & 0 deletions test/sql/test_agg/R/test_orderby_agg
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
-- name: test_orderby_agg
CREATE TABLE `t0` (
`v1` bigint(20) NULL COMMENT "",
`v2` bigint(20) NULL COMMENT "",
`v3` bigint(20) NULL COMMENT "",
`v4` varchar NULL COMMENT ""
) ENGINE=OLAP
DUPLICATE KEY(`v1`, `v2`, `v3`)
DISTRIBUTED BY HASH(`v1`) BUCKETS 3
PROPERTIES (
"replication_num" = "1",
"in_memory" = "false",
"enable_persistent_index" = "false",
"replicated_storage" = "true",
"compression" = "LZ4"
);
-- result:
-- !result
insert into t0 values(1, 2, 3, 'a'), (1, 3, 4, 'b'), (2, 3, 4, 'a'), (null, 1, null, 'c'), (4, null, 1 , null),
(5, 1 , 3, 'c'), (2, 2, null, 'a'), (4, null, 4, 'c'), (null, null, 2, null);
-- result:
-- !result
select min(v1) v1 from t0 group by v3 order by round(count(v2) / min(v1)), min(v1);
-- result:
None
4
2
1
1
-- !result
select min(v1) v1, round(count(v2) / min(v1)) round_col from t0 group by v3 order by abs(min(v1)) + abs(v1) asc;
-- result:
None None
1 2
1 2
2 1
4 0
-- !result
select min(v1) v1 from t0 group by v3 order by round(count(v2) / min(v1)), abs(v1);
-- result:
None
4
2
1
1
-- !result
select min(v1) v11 from t0 group by v3 order by round(count(v2) / min(v1)), abs(v11);
-- result:
None
4
2
1
1
-- !result
select min(v1) v11, min(v1) v1 from t0 group by v3 order by round(count(v2) / min(v1)), abs(v11), abs(v1);
-- result:
None None
4 4
2 2
1 1
1 1
-- !result
select round(count(v1) * 100.0 / min(v2), 4) as potential_customer_rate, min(v2) v2 from t0 group by v4 order by round(count(v1) * 100.0 / min(v2), 4), min(v2);
-- result:
None None
33.3333 3
150.0000 2
200.0000 1
-- !result
select round(count(v1) * 100.0 / min(v2), 4) as potential_customer_rate, min(v2) v2 from t0 group by v4 order by round(count(v1) * 100.0 / min(v2), 4), abs(v2);
-- result:
None None
33.3333 3
150.0000 2
200.0000 1
-- !result
29 changes: 29 additions & 0 deletions test/sql/test_agg/T/test_orderby_agg
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
-- name: test_orderby_agg

-- Queries whose ORDER BY combines aggregate calls with names that are also
-- select-list aliases (some aliases reuse a base column name such as v1 or v2).

CREATE TABLE `t0` (
`v1` bigint(20) NULL COMMENT "",
`v2` bigint(20) NULL COMMENT "",
`v3` bigint(20) NULL COMMENT "",
`v4` varchar NULL COMMENT ""
) ENGINE=OLAP
DUPLICATE KEY(`v1`, `v2`, `v3`)
DISTRIBUTED BY HASH(`v1`) BUCKETS 3
PROPERTIES (
"replication_num" = "1",
"in_memory" = "false",
"enable_persistent_index" = "false",
"replicated_storage" = "true",
"compression" = "LZ4"
);

-- Sample rows; every column contains at least one NULL.
insert into t0 values(1, 2, 3, 'a'), (1, 3, 4, 'b'), (2, 3, 4, 'a'), (null, 1, null, 'c'), (4, null, 1 , null),
(5, 1 , 3, 'c'), (2, 2, null, 'a'), (4, null, 4, 'c'), (null, null, 2, null);

-- Sort keys are aggregate expressions only; the output alias v1 matches the column used inside min(v1).
select min(v1) v1 from t0 group by v3 order by round(count(v2) / min(v1)), min(v1);
-- A single sort key that adds an aggregate over v1 to abs() of a bare v1.
select min(v1) v1, round(count(v2) / min(v1)) round_col from t0 group by v3 order by abs(min(v1)) + abs(v1) asc;

-- Aggregate sort key followed by abs() of a bare v1, a name shared by the base column and the output alias.
select min(v1) v1 from t0 group by v3 order by round(count(v2) / min(v1)), abs(v1);
-- Same shape, but the alias (v11) does not collide with any base column name.
select min(v1) v11 from t0 group by v3 order by round(count(v2) / min(v1)), abs(v11);
-- Two aliases over the same aggregate: one distinct name (v11) and one reusing the column name (v1).
select min(v1) v11, min(v1) v1 from t0 group by v3 order by round(count(v2) / min(v1)), abs(v11), abs(v1);
-- Sort keys repeat the select-list aggregate expressions verbatim.
select round(count(v1) * 100.0 / min(v2), 4) as potential_customer_rate, min(v2) v2 from t0 group by v4 order by round(count(v1) * 100.0 / min(v2), 4), min(v2);
-- As above, but the second sort key is abs() of a bare v2, which is both a column and an output alias.
select round(count(v1) * 100.0 / min(v2), 4) as potential_customer_rate, min(v2) v2 from t0 group by v4 order by round(count(v1) * 100.0 / min(v2), 4), abs(v2);

0 comments on commit 7308217

Please sign in to comment.