diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/CachingMvPlanContextBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/CachingMvPlanContextBuilder.java index 6ca19f573eb2b..cf55c394148cb 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/CachingMvPlanContextBuilder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/CachingMvPlanContextBuilder.java @@ -22,6 +22,7 @@ import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import com.starrocks.analysis.OrderByElement; import com.starrocks.analysis.ParseNode; import com.starrocks.catalog.MaterializedView; import com.starrocks.catalog.MvPlanContext; @@ -29,6 +30,9 @@ import com.starrocks.common.FeConstants; import com.starrocks.qe.SessionVariable; import com.starrocks.sql.analyzer.AstToSQLBuilder; +import com.starrocks.sql.ast.QueryRelation; +import com.starrocks.sql.ast.QueryStatement; +import org.apache.commons.collections4.CollectionUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.checkerframework.checker.nullness.qual.NonNull; @@ -89,7 +93,6 @@ public static class AstKey { /** * Create a AstKey with parseNode(sub parse node) - * @param parseNode */ public AstKey(ParseNode parseNode) { this.sql = new AstToSQLBuilder.AST2SQLBuilderVisitor(true, false, true).visit(parseNode); @@ -248,15 +251,25 @@ public void updateMvPlanContextCache(MaterializedView mv, boolean isActive) { public void invalidateAstFromCache(MaterializedView mv) { try { - ParseNode parseNode = mv.getDefineQueryParseNode(); - if (parseNode == null) { + List astKeys = getAstKeysOfMV(mv); + if (CollectionUtils.isEmpty(astKeys)) { return; } - AstKey astKey = new AstKey(parseNode); - if (!AST_TO_MV_MAP.containsKey(astKey)) { - return; + synchronized (AST_TO_MV_MAP) { + for (AstKey astKey : astKeys) { + if (!AST_TO_MV_MAP.containsKey(astKey)) { + continue; + } 
+ // remove mv from ast cache + Set relatedMVs = AST_TO_MV_MAP.get(astKey); + relatedMVs.remove(mv); + + // remove ast key if no related mvs + if (relatedMVs.isEmpty()) { + AST_TO_MV_MAP.remove(astKey); + } + } } - AST_TO_MV_MAP.get(astKey).remove(mv); LOG.info("Remove mv {} from ast cache", mv.getName()); } catch (Exception e) { LOG.warn("invalidateAstFromCache failed: {}", mv.getName(), e); @@ -272,18 +285,49 @@ public void putAstIfAbsent(MaterializedView mv) { } try { // cache by ast - ParseNode parseNode = mv.getDefineQueryParseNode(); - if (parseNode == null) { + List astKeys = getAstKeysOfMV(mv); + if (CollectionUtils.isEmpty(astKeys)) { return; } - AST_TO_MV_MAP.computeIfAbsent(new AstKey(parseNode), ignored -> Sets.newHashSet()) - .add(mv); + synchronized (AST_TO_MV_MAP) { + for (AstKey astKey : astKeys) { + AST_TO_MV_MAP.computeIfAbsent(astKey, ignored -> Sets.newHashSet()).add(mv); + } + } LOG.info("Add mv {} input ast cache", mv.getName()); } catch (Exception e) { LOG.warn("putAstIfAbsent failed: {}", mv.getName(), e); } } + private List getAstKeysOfMV(MaterializedView mv) { + List keys = Lists.newArrayList(); + ParseNode parseNode = mv.getDefineQueryParseNode(); + if (parseNode == null) { + return keys; + } + // add the complete ast tree + keys.add(new AstKey(parseNode)); + if (!(parseNode instanceof QueryStatement)) { + return keys; + } + + // add the ast tree without an order by clause + QueryStatement queryStatement = (QueryStatement) parseNode; + QueryRelation queryRelation = queryStatement.getQueryRelation(); + if (!queryRelation.hasLimit() && queryRelation.hasOrderByClause()) { + // temporarily clear the order by on the mv's own query relation to build the second key; the saved elements are restored in the finally block below. 
+ List orderByElements = Lists.newArrayList(queryRelation.getOrderBy()); + try { + queryRelation.clearOrder(); + keys.add(new AstKey(parseNode)); + } finally { + queryRelation.setOrderBy(orderByElements); + } + } + return keys; + } + /** * @return: null if parseNode is null or astToMvsMap doesn't contain this ast, otherwise return the mvs */ diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/materialization/rule/TextMatchBasedRewriteRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/materialization/rule/TextMatchBasedRewriteRule.java index 0251ddab7e8bb..7d5bef5af7dd2 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/materialization/rule/TextMatchBasedRewriteRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/materialization/rule/TextMatchBasedRewriteRule.java @@ -91,8 +91,6 @@ public class TextMatchBasedRewriteRule extends Rule { // limit for mvs which matched input query(default 64) private final long mvRewriteRelatedMVsLimit; - private int subQueryTextMatchCount = 1; - public TextMatchBasedRewriteRule(ConnectContext connectContext, StatementBase stmt, MVTransformerContext mvTransformerContext) { @@ -101,10 +99,11 @@ public TextMatchBasedRewriteRule(ConnectContext connectContext, this.connectContext = connectContext; this.stmt = stmt; this.mvTransformerContext = mvTransformerContext; - this.mvSubQueryTextMatchMaxCount = - connectContext.getSessionVariable().getMaterializedViewSubQueryTextMatchMaxCount(); this.mvRewriteRelatedMVsLimit = connectContext.getSessionVariable().getCboMaterializedViewRewriteRelatedMVsLimit(); + this.mvSubQueryTextMatchMaxCount = + connectContext.getSessionVariable().getMaterializedViewSubQueryTextMatchMaxCount(); + } @Override @@ -370,12 +369,19 @@ private OptExpression doTextMatchBasedRewrite(OptimizerContext context, class TextBasedRewriteVisitor extends OptExpressionVisitor { private final 
OptimizerContext optimizerContext; private final MVTransformerContext mvTransformerContext; + // number of sub-query text match attempts made so far by this visitor (starts at 0, bumped in doRewrite) + private int subQueryTextMatchCount = 0; + public TextBasedRewriteVisitor(OptimizerContext optimizerContext, MVTransformerContext mvTransformerContext) { this.optimizerContext = optimizerContext; this.mvTransformerContext = mvTransformerContext; } + private boolean isReachLimit() { + return subQueryTextMatchCount >= mvSubQueryTextMatchMaxCount; + } + private List visitChildren(OptExpression optExpression, ConnectContext connectContext) { List children = com.google.common.collect.Lists.newArrayList(); for (OptExpression child : optExpression.getInputs()) { @@ -384,17 +390,10 @@ private List visitChildren(OptExpression optExpression, ConnectCo return children; } - private boolean isReachLimit() { - return subQueryTextMatchCount++ > mvSubQueryTextMatchMaxCount; - } - @Override public OptExpression visit(OptExpression optExpression, ConnectContext connectContext) { LogicalOperator op = (LogicalOperator) optExpression.getOp(); if (SUPPORTED_REWRITE_OPERATOR_TYPES.contains(op.getOpType())) { - if (isReachLimit()) { - return optExpression; - } OptExpression rewritten = doRewrite(optExpression); if (rewritten != null) { return rewritten; @@ -406,9 +405,14 @@ public OptExpression visit(OptExpression optExpression, ConnectContext connectCo private OptExpression doRewrite(OptExpression input) { Operator op = input.getOp(); - if (!mvTransformerContext.hasOpAST(op)) { + if (!mvTransformerContext.hasOpAST(op) || isReachLimit()) { return null; } + + // if op is in the AST map, which means op is a sub-query, then rewrite it. 
+ subQueryTextMatchCount += 1; + + // try to rewrite by text match ParseNode parseNode = mvTransformerContext.getOpAST(op); OptExpression rewritten = rewriteByTextMatch(input, optimizerContext, new CachingMvPlanContextBuilder.AstKey(parseNode)); diff --git a/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTextBasedRewriteTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTextBasedRewriteTest.java index 2fba936f4ca66..9ec5f3b3ba3d5 100644 --- a/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTextBasedRewriteTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTextBasedRewriteTest.java @@ -15,6 +15,7 @@ package com.starrocks.planner; import com.starrocks.analysis.ParseNode; +import com.starrocks.sql.common.QueryDebugOptions; import com.starrocks.sql.optimizer.CachingMvPlanContextBuilder; import com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils; import com.starrocks.sql.plan.PlanTestBase; @@ -31,6 +32,9 @@ public static void beforeClass() throws Exception { MaterializedViewTestBase.beforeClass(); connectContext.getSessionVariable().setEnableMaterializedViewTextMatchRewrite(true); starRocksAssert.useDatabase(MATERIALIZED_DB_NAME); + QueryDebugOptions debugOptions = new QueryDebugOptions(); + debugOptions.setEnableQueryTraceLog(true); + connectContext.getSessionVariable().setQueryDebugOptions(debugOptions.toString()); starRocksAssert.withTable("create table user_tags (" + " time date, " + " user_id int, " + @@ -136,9 +140,8 @@ public void testTextMatchRewrite2() { String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) from user_tags group by user_id, time " + " order by user_id, time"; // sub-query's order by will be squashed. 
- testRewriteFail(mv, "select * from (" + mv + ") as a;"); - // TODO: support order by push-down - testRewriteFail(mv, "select * from (select user_id, time, bitmap_union(to_bitmap(tag_id)) " + + testRewriteOK(mv, "select * from (" + mv + ") as a;"); + testRewriteOK(mv, "select * from (select user_id, time, bitmap_union(to_bitmap(tag_id)) " + "from user_tags group by user_id, time) as a order by user_id, time;"); } @@ -319,8 +322,7 @@ public void testTextMatchRewriteWithSubQuery3() { String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time order by " + "user_id, time"; String query = String.format("select user_id, count(time) from (%s) as t group by user_id limit 3;", mv); - // TODO: support order by elimiation - testRewriteFail(mv, query); + testRewriteOK(mv, query); } @Test @@ -331,12 +333,33 @@ public void testTextMatchRewriteWithSubQuery4() { testRewriteOK(mv, query); } + @Test + public void testTextMatchRewriteWithSubQuery5() { + String mv = "select * from (" + + " select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " + + " group by user_id, time order by user_id, time) s where user_id != 'xxxx'"; + String query = String.format("with cte1 as (select * from (%s) as t) select user_id, count(time) " + + " from cte1 as t group by user_id limit 3;", mv); + testRewriteOK(mv, query); + } + + @Test + public void testTextMatchRewriteWithSubQuery6() { + String mv = "select * from (" + + " select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " + + " group by user_id, time order by user_id, time) s where user_id != 'xxxx'"; + String query = String.format("with cte1 as (select * from (%s) as t) select user_id, count(time) " + + " from cte1 as t group by user_id limit 3;", mv); + testRewriteOK(mv, query); + } + @Test public void testTextMatchRewriteWithExtraOrder1() { String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time"; 
String query = String.format("select user_id from (%s) t order by user_id, time;", mv); testRewriteOK(mv, query); } + @Test public void testTextMatchRewriteWithExtraOrder2() { String mv = "select user_id, count(1) from (select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " +