Skip to content

Commit

Permalink
[Enhancement] Enhance text base mv rewrite (#52498)
Browse files Browse the repository at this point in the history
Signed-off-by: shuming.li <[email protected]>
(cherry picked from commit 6029ee2)

# Conflicts:
#	fe/fe-core/src/main/java/com/starrocks/sql/optimizer/CachingMvPlanContextBuilder.java
  • Loading branch information
LiShuMing authored and mergify[bot] committed Nov 1, 2024
1 parent feafa41 commit 5b2e00b
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,19 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
<<<<<<< HEAD
=======
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.starrocks.analysis.OrderByElement;
>>>>>>> 6029ee28bc ([Enhancement] Enhance text base mv rewrite (#52498))
import com.starrocks.analysis.ParseNode;
import com.starrocks.catalog.MaterializedView;
import com.starrocks.catalog.MvPlanContext;
import com.starrocks.common.Config;
import com.starrocks.sql.analyzer.AstToSQLBuilder;
import com.starrocks.sql.ast.QueryRelation;
import com.starrocks.sql.ast.QueryStatement;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand All @@ -49,7 +57,6 @@ public static class AstKey {

/**
* Create an AstKey from the given parse node (which may be a sub parse node).
* @param parseNode the parse node that the key is built from
*/
public AstKey(ParseNode parseNode) {
this.sql = new AstToSQLBuilder.AST2SQLBuilderVisitor(true, false, true).visit(parseNode);
Expand Down Expand Up @@ -150,15 +157,33 @@ public void invalidateFromCache(MaterializedView mv, boolean isActive) {

public void invalidateAstFromCache(MaterializedView mv) {
try {
ParseNode parseNode = mv.getDefineQueryParseNode();
if (parseNode == null) {
List<AstKey> astKeys = getAstKeysOfMV(mv);
if (CollectionUtils.isEmpty(astKeys)) {
return;
}
<<<<<<< HEAD
AstKey astKey = new AstKey(parseNode);
if (!astToMvsMap.containsKey(astKey)) {
return;
}
astToMvsMap.get(astKey).remove(mv);
=======
synchronized (AST_TO_MV_MAP) {
for (AstKey astKey : astKeys) {
if (!AST_TO_MV_MAP.containsKey(astKey)) {
continue;
}
// remove mv from ast cache
Set<MaterializedView> relatedMVs = AST_TO_MV_MAP.get(astKey);
relatedMVs.remove(mv);

// remove ast key if no related mvs
if (relatedMVs.isEmpty()) {
AST_TO_MV_MAP.remove(astKey);
}
}
}
>>>>>>> 6029ee28bc ([Enhancement] Enhance text base mv rewrite (#52498))
LOG.info("Remove mv {} from ast cache", mv.getName());
} catch (Exception e) {
LOG.warn("invalidateAstFromCache failed: {}", mv.getName(), e);
Expand All @@ -174,18 +199,54 @@ public void putAstIfAbsent(MaterializedView mv) {
}
try {
// cache by ast
ParseNode parseNode = mv.getDefineQueryParseNode();
if (parseNode == null) {
List<AstKey> astKeys = getAstKeysOfMV(mv);
if (CollectionUtils.isEmpty(astKeys)) {
return;
}
<<<<<<< HEAD
astToMvsMap.computeIfAbsent(new AstKey(parseNode), ignored -> Sets.newHashSet())
.add(mv);
=======
synchronized (AST_TO_MV_MAP) {
for (AstKey astKey : astKeys) {
AST_TO_MV_MAP.computeIfAbsent(astKey, ignored -> Sets.newHashSet()).add(mv);
}
}
>>>>>>> 6029ee28bc ([Enhancement] Enhance text base mv rewrite (#52498))
LOG.info("Add mv {} input ast cache", mv.getName());
} catch (Exception e) {
LOG.warn("putAstIfAbsent failed: {}", mv.getName(), e);
}
}

/**
 * Builds every {@link AstKey} that is derived from the given mv's define query.
 *
 * <p>The first key is always built from the complete define query. When that query is a
 * {@link QueryStatement} whose relation has an order by clause but no limit, a second key is
 * built from the same query while its order by is temporarily cleared; the original order by
 * elements are put back afterwards, even if building the key fails.
 *
 * @param mv the materialized view whose define query is used to build the keys
 * @return the keys of {@code mv}; empty when the mv has no define query parse node
 */
private List<AstKey> getAstKeysOfMV(MaterializedView mv) {
    final List<AstKey> astKeys = Lists.newArrayList();
    final ParseNode defineQuery = mv.getDefineQueryParseNode();
    if (defineQuery == null) {
        return astKeys;
    }

    // key #1: the define query exactly as written
    astKeys.add(new AstKey(defineQuery));

    if (defineQuery instanceof QueryStatement) {
        final QueryRelation relation = ((QueryStatement) defineQuery).getQueryRelation();
        final boolean orderByWithoutLimit = !relation.hasLimit() && relation.hasOrderByClause();
        if (orderByWithoutLimit) {
            // key #2: the define query with its order by cleared. The relation is mutated in
            // place, so keep a copy of the order by elements and restore them in finally.
            final List<OrderByElement> savedOrderBy = Lists.newArrayList(relation.getOrderBy());
            try {
                relation.clearOrder();
                astKeys.add(new AstKey(defineQuery));
            } finally {
                relation.setOrderBy(savedOrderBy);
            }
        }
    }
    return astKeys;
}

/**
* @param parseNode: ast to query.
* @return: null if parseNode is null or astToMvsMap doesn't contain this ast, otherwise return the mvs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ public class TextMatchBasedRewriteRule extends Rule {
// limit for mvs which matched input query(default 64)
private final long mvRewriteRelatedMVsLimit;

private int subQueryTextMatchCount = 1;

public TextMatchBasedRewriteRule(ConnectContext connectContext,
StatementBase stmt,
MVTransformerContext mvTransformerContext) {
Expand All @@ -99,10 +97,11 @@ public TextMatchBasedRewriteRule(ConnectContext connectContext,
this.connectContext = connectContext;
this.stmt = stmt;
this.mvTransformerContext = mvTransformerContext;
this.mvSubQueryTextMatchMaxCount =
connectContext.getSessionVariable().getMaterializedViewSubQueryTextMatchMaxCount();
this.mvRewriteRelatedMVsLimit =
connectContext.getSessionVariable().getCboMaterializedViewRewriteRelatedMVsLimit();
this.mvSubQueryTextMatchMaxCount =
connectContext.getSessionVariable().getMaterializedViewSubQueryTextMatchMaxCount();

}

@Override
Expand Down Expand Up @@ -360,12 +359,19 @@ private OptExpression doTextMatchBasedRewrite(OptimizerContext context,
class TextBasedRewriteVisitor extends OptExpressionVisitor<OptExpression, ConnectContext> {
private final OptimizerContext optimizerContext;
private final MVTransformerContext mvTransformerContext;
// sub-query text match count
private int subQueryTextMatchCount = 0;

/**
 * Creates a visitor that keeps the two contexts it is given for later use.
 *
 * @param optimizerContext     optimizer context stored on this visitor
 * @param mvTransformerContext transformer context stored on this visitor
 */
public TextBasedRewriteVisitor(OptimizerContext optimizerContext,
                               MVTransformerContext mvTransformerContext) {
    // The two assignments are independent of each other.
    this.mvTransformerContext = mvTransformerContext;
    this.optimizerContext = optimizerContext;
}

/**
 * Tells whether this visitor has used up its sub-query text match budget.
 *
 * @return {@code true} once {@code subQueryTextMatchCount} has reached
 *         {@code mvSubQueryTextMatchMaxCount}
 */
private boolean isReachLimit() {
    // subQueryTextMatchCount starts at 0 and is incremented after this check passes, so `>=`
    // caps the number of attempts at exactly the configured maximum; `>` allowed max + 1.
    return subQueryTextMatchCount >= mvSubQueryTextMatchMaxCount;
}

private List<OptExpression> visitChildren(OptExpression optExpression, ConnectContext connectContext) {
List<OptExpression> children = com.google.common.collect.Lists.newArrayList();
for (OptExpression child : optExpression.getInputs()) {
Expand All @@ -374,17 +380,10 @@ private List<OptExpression> visitChildren(OptExpression optExpression, ConnectCo
return children;
}

private boolean isReachLimit() {
return subQueryTextMatchCount++ > mvSubQueryTextMatchMaxCount;
}

@Override
public OptExpression visit(OptExpression optExpression, ConnectContext connectContext) {
LogicalOperator op = (LogicalOperator) optExpression.getOp();
if (SUPPORTED_REWRITE_OPERATOR_TYPES.contains(op.getOpType())) {
if (isReachLimit()) {
return optExpression;
}
OptExpression rewritten = doRewrite(optExpression);
if (rewritten != null) {
return rewritten;
Expand All @@ -396,9 +395,14 @@ public OptExpression visit(OptExpression optExpression, ConnectContext connectCo

private OptExpression doRewrite(OptExpression input) {
Operator op = input.getOp();
if (!mvTransformerContext.hasOpAST(op)) {
if (!mvTransformerContext.hasOpAST(op) || isReachLimit()) {
return null;
}

// if op is in the AST map, which means op is a sub-query, then rewrite it.
subQueryTextMatchCount += 1;

// try to rewrite by text match
ParseNode parseNode = mvTransformerContext.getOpAST(op);
OptExpression rewritten = rewriteByTextMatch(input, optimizerContext,
new CachingMvPlanContextBuilder.AstKey(parseNode));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.starrocks.planner;

import com.starrocks.analysis.ParseNode;
import com.starrocks.sql.common.QueryDebugOptions;
import com.starrocks.sql.optimizer.CachingMvPlanContextBuilder;
import com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils;
import com.starrocks.sql.plan.PlanTestBase;
Expand All @@ -31,6 +32,9 @@ public static void beforeClass() throws Exception {
MaterializedViewTestBase.beforeClass();
connectContext.getSessionVariable().setEnableMaterializedViewTextMatchRewrite(true);
starRocksAssert.useDatabase(MATERIALIZED_DB_NAME);
QueryDebugOptions debugOptions = new QueryDebugOptions();
debugOptions.setEnableQueryTraceLog(true);
connectContext.getSessionVariable().setQueryDebugOptions(debugOptions.toString());
starRocksAssert.withTable("create table user_tags (" +
" time date, " +
" user_id int, " +
Expand Down Expand Up @@ -136,9 +140,8 @@ public void testTextMatchRewrite2() {
String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) from user_tags group by user_id, time " +
" order by user_id, time";
// sub-query's order by will be squashed.
testRewriteFail(mv, "select * from (" + mv + ") as a;");
// TODO: support order by push-down
testRewriteFail(mv, "select * from (select user_id, time, bitmap_union(to_bitmap(tag_id)) " +
testRewriteOK(mv, "select * from (" + mv + ") as a;");
testRewriteOK(mv, "select * from (select user_id, time, bitmap_union(to_bitmap(tag_id)) " +
"from user_tags group by user_id, time) as a order by user_id, time;");
}

Expand Down Expand Up @@ -319,8 +322,7 @@ public void testTextMatchRewriteWithSubQuery3() {
String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time order by " +
"user_id, time";
String query = String.format("select user_id, count(time) from (%s) as t group by user_id limit 3;", mv);
// TODO: support order by elimiation
testRewriteFail(mv, query);
testRewriteOK(mv, query);
}

@Test
Expand All @@ -331,12 +333,33 @@ public void testTextMatchRewriteWithSubQuery4() {
testRewriteOK(mv, query);
}

@Test
public void testTextMatchRewriteWithSubQuery5() {
    // mv: a filtered outer select over an aggregated sub-query that carries an order by.
    final String mvDefinition = "select * from (" +
            " select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " +
            " group by user_id, time order by user_id, time) s where user_id != 'xxxx'";
    // query: the mv text wrapped in a cte, then aggregated again with a limit.
    final String queryTemplate = "with cte1 as (select * from (%s) as t) select user_id, count(time) " +
            " from cte1 as t group by user_id limit 3;";
    testRewriteOK(mvDefinition, String.format(queryTemplate, mvDefinition));
}

@Test
public void testTextMatchRewriteWithSubQuery6() {
    // NOTE(review): the mv and query here are identical to testTextMatchRewriteWithSubQuery5 —
    // confirm whether a different scenario was intended for this case.
    // mv: a filtered outer select over an aggregated sub-query that carries an order by.
    final String mvDefinition = "select * from (" +
            " select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " +
            " group by user_id, time order by user_id, time) s where user_id != 'xxxx'";
    // query: the mv text wrapped in a cte, then aggregated again with a limit.
    final String queryTemplate = "with cte1 as (select * from (%s) as t) select user_id, count(time) " +
            " from cte1 as t group by user_id limit 3;";
    testRewriteOK(mvDefinition, String.format(queryTemplate, mvDefinition));
}

@Test
public void testTextMatchRewriteWithExtraOrder1() {
    // mv has no order by; the query wraps the mv text in a sub-query and orders the outer select.
    final String mvDefinition = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a " +
            "from user_tags group by user_id, time";
    final String queryTemplate = "select user_id from (%s) t order by user_id, time;";
    testRewriteOK(mvDefinition, String.format(queryTemplate, mvDefinition));
}

@Test
public void testTextMatchRewriteWithExtraOrder2() {
String mv = "select user_id, count(1) from (select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " +
Expand Down

0 comments on commit 5b2e00b

Please sign in to comment.