Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Enhance text base mv rewrite #52498

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.starrocks.analysis.OrderByElement;
import com.starrocks.analysis.ParseNode;
import com.starrocks.catalog.MaterializedView;
import com.starrocks.catalog.MvPlanContext;
import com.starrocks.common.Config;
import com.starrocks.common.FeConstants;
import com.starrocks.qe.SessionVariable;
import com.starrocks.sql.analyzer.AstToSQLBuilder;
import com.starrocks.sql.ast.QueryRelation;
import com.starrocks.sql.ast.QueryStatement;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.checkerframework.checker.nullness.qual.NonNull;
Expand Down Expand Up @@ -89,7 +93,6 @@ public static class AstKey {

/**
* Create an AstKey with parseNode (sub parse node)
* @param parseNode
*/
public AstKey(ParseNode parseNode) {
this.sql = new AstToSQLBuilder.AST2SQLBuilderVisitor(true, false, true).visit(parseNode);
Expand Down Expand Up @@ -248,15 +251,25 @@ public void updateMvPlanContextCache(MaterializedView mv, boolean isActive) {

public void invalidateAstFromCache(MaterializedView mv) {
try {
ParseNode parseNode = mv.getDefineQueryParseNode();
if (parseNode == null) {
List<AstKey> astKeys = getAstKeysOfMV(mv);
if (CollectionUtils.isEmpty(astKeys)) {
return;
}
AstKey astKey = new AstKey(parseNode);
if (!AST_TO_MV_MAP.containsKey(astKey)) {
return;
synchronized (AST_TO_MV_MAP) {
for (AstKey astKey : astKeys) {
if (!AST_TO_MV_MAP.containsKey(astKey)) {
continue;
}
// remove mv from ast cache
Set<MaterializedView> relatedMVs = AST_TO_MV_MAP.get(astKey);
relatedMVs.remove(mv);

// remove ast key if no related mvs
if (relatedMVs.isEmpty()) {
AST_TO_MV_MAP.remove(astKey);
}
}
}
AST_TO_MV_MAP.get(astKey).remove(mv);
LOG.info("Remove mv {} from ast cache", mv.getName());
} catch (Exception e) {
LOG.warn("invalidateAstFromCache failed: {}", mv.getName(), e);
Expand All @@ -272,18 +285,49 @@ public void putAstIfAbsent(MaterializedView mv) {
}
try {
// cache by ast
ParseNode parseNode = mv.getDefineQueryParseNode();
if (parseNode == null) {
List<AstKey> astKeys = getAstKeysOfMV(mv);
if (CollectionUtils.isEmpty(astKeys)) {
return;
}
AST_TO_MV_MAP.computeIfAbsent(new AstKey(parseNode), ignored -> Sets.newHashSet())
.add(mv);
synchronized (AST_TO_MV_MAP) {
for (AstKey astKey : astKeys) {
AST_TO_MV_MAP.computeIfAbsent(astKey, ignored -> Sets.newHashSet()).add(mv);
}
}
LOG.info("Add mv {} input ast cache", mv.getName());
} catch (Exception e) {
LOG.warn("putAstIfAbsent failed: {}", mv.getName(), e);
}
}

/**
 * Build the AST keys derived from the define-query parse node of the given materialized view.
 *
 * <p>The result always starts with a key for the complete parse node. When the parse node is a
 * {@link QueryStatement} whose query relation has an order by clause and no limit, a second key
 * is appended that is built while the order by clause is cleared.
 *
 * @param mv the materialized view whose define query is used to build the keys
 * @return the keys for {@code mv}; an empty list when the mv has no define-query parse node
 */
private List<AstKey> getAstKeysOfMV(MaterializedView mv) {
    final List<AstKey> astKeys = Lists.newArrayList();
    final ParseNode defineQuery = mv.getDefineQueryParseNode();
    if (defineQuery == null) {
        return astKeys;
    }

    // first key: the complete ast tree, exactly as defined
    astKeys.add(new AstKey(defineQuery));

    if (defineQuery instanceof QueryStatement) {
        final QueryRelation relation = ((QueryStatement) defineQuery).getQueryRelation();
        // second key: the same ast tree without its order by clause (only when there is no limit)
        final boolean hasOrderByWithoutLimit = !relation.hasLimit() && relation.hasOrderByClause();
        if (hasOrderByWithoutLimit) {
            // The relation is modified in place only while the key is built: a copy of the order by
            // elements is taken first and is always put back in the finally block.
            // NOTE(review): the parse node is mutated temporarily here; confirm no other thread
            // reads the mv's define query while this runs.
            final List<OrderByElement> savedOrderBy = Lists.newArrayList(relation.getOrderBy());
            try {
                relation.clearOrder();
                astKeys.add(new AstKey(defineQuery));
            } finally {
                relation.setOrderBy(savedOrderBy);
            }
        }
    }
    return astKeys;
}

/**
* @return: null if parseNode is null or astToMvsMap doesn't contain this ast, otherwise return the mvs
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ public class TextMatchBasedRewriteRule extends Rule {
// limit for mvs which matched input query(default 64)
private final long mvRewriteRelatedMVsLimit;

private int subQueryTextMatchCount = 1;

public TextMatchBasedRewriteRule(ConnectContext connectContext,
StatementBase stmt,
MVTransformerContext mvTransformerContext) {
Expand All @@ -101,10 +99,11 @@ public TextMatchBasedRewriteRule(ConnectContext connectContext,
this.connectContext = connectContext;
this.stmt = stmt;
this.mvTransformerContext = mvTransformerContext;
this.mvSubQueryTextMatchMaxCount =
connectContext.getSessionVariable().getMaterializedViewSubQueryTextMatchMaxCount();
this.mvRewriteRelatedMVsLimit =
connectContext.getSessionVariable().getCboMaterializedViewRewriteRelatedMVsLimit();
this.mvSubQueryTextMatchMaxCount =
connectContext.getSessionVariable().getMaterializedViewSubQueryTextMatchMaxCount();

}

@Override
Expand Down Expand Up @@ -370,12 +369,19 @@ private OptExpression doTextMatchBasedRewrite(OptimizerContext context,
class TextBasedRewriteVisitor extends OptExpressionVisitor<OptExpression, ConnectContext> {
private final OptimizerContext optimizerContext;
private final MVTransformerContext mvTransformerContext;
// sub-query text match count
private int subQueryTextMatchCount = 0;

public TextBasedRewriteVisitor(OptimizerContext optimizerContext,
MVTransformerContext mvTransformerContext) {
this.optimizerContext = optimizerContext;
this.mvTransformerContext = mvTransformerContext;
}

/**
 * Tell whether the sub-query text match budget of this visitor is used up.
 *
 * <p>{@code subQueryTextMatchCount} starts at 0 and is incremented by {@code doRewrite} only
 * after this check has returned {@code false}, so the counter holds the number of sub-query
 * matches already attempted. Comparing with {@code >=} makes {@code mvSubQueryTextMatchMaxCount}
 * the real upper bound; a strict {@code >} would allow one attempt more than the maximum.
 *
 * @return {@code true} when no further sub-query text match should be attempted
 */
private boolean isReachLimit() {
    return subQueryTextMatchCount >= mvSubQueryTextMatchMaxCount;
}

private List<OptExpression> visitChildren(OptExpression optExpression, ConnectContext connectContext) {
List<OptExpression> children = com.google.common.collect.Lists.newArrayList();
for (OptExpression child : optExpression.getInputs()) {
Expand All @@ -384,17 +390,10 @@ private List<OptExpression> visitChildren(OptExpression optExpression, ConnectCo
return children;
}

private boolean isReachLimit() {
return subQueryTextMatchCount++ > mvSubQueryTextMatchMaxCount;
}

@Override
public OptExpression visit(OptExpression optExpression, ConnectContext connectContext) {
LogicalOperator op = (LogicalOperator) optExpression.getOp();
if (SUPPORTED_REWRITE_OPERATOR_TYPES.contains(op.getOpType())) {
if (isReachLimit()) {
return optExpression;
}
OptExpression rewritten = doRewrite(optExpression);
if (rewritten != null) {
return rewritten;
Expand All @@ -406,9 +405,14 @@ public OptExpression visit(OptExpression optExpression, ConnectContext connectCo

private OptExpression doRewrite(OptExpression input) {
Operator op = input.getOp();
if (!mvTransformerContext.hasOpAST(op)) {
if (!mvTransformerContext.hasOpAST(op) || isReachLimit()) {
return null;
}

// if op is in the AST map, which means op is a sub-query, then rewrite it.
subQueryTextMatchCount += 1;

// try to rewrite by text match
ParseNode parseNode = mvTransformerContext.getOpAST(op);
OptExpression rewritten = rewriteByTextMatch(input, optimizerContext,
new CachingMvPlanContextBuilder.AstKey(parseNode));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.starrocks.planner;

import com.starrocks.analysis.ParseNode;
import com.starrocks.sql.common.QueryDebugOptions;
import com.starrocks.sql.optimizer.CachingMvPlanContextBuilder;
import com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils;
import com.starrocks.sql.plan.PlanTestBase;
Expand All @@ -31,6 +32,9 @@ public static void beforeClass() throws Exception {
MaterializedViewTestBase.beforeClass();
connectContext.getSessionVariable().setEnableMaterializedViewTextMatchRewrite(true);
starRocksAssert.useDatabase(MATERIALIZED_DB_NAME);
QueryDebugOptions debugOptions = new QueryDebugOptions();
debugOptions.setEnableQueryTraceLog(true);
connectContext.getSessionVariable().setQueryDebugOptions(debugOptions.toString());
starRocksAssert.withTable("create table user_tags (" +
" time date, " +
" user_id int, " +
Expand Down Expand Up @@ -136,9 +140,8 @@ public void testTextMatchRewrite2() {
String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) from user_tags group by user_id, time " +
" order by user_id, time";
// sub-query's order by will be squashed.
testRewriteFail(mv, "select * from (" + mv + ") as a;");
// TODO: support order by push-down
testRewriteFail(mv, "select * from (select user_id, time, bitmap_union(to_bitmap(tag_id)) " +
testRewriteOK(mv, "select * from (" + mv + ") as a;");
testRewriteOK(mv, "select * from (select user_id, time, bitmap_union(to_bitmap(tag_id)) " +
"from user_tags group by user_id, time) as a order by user_id, time;");
}

Expand Down Expand Up @@ -319,8 +322,7 @@ public void testTextMatchRewriteWithSubQuery3() {
String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time order by " +
"user_id, time";
String query = String.format("select user_id, count(time) from (%s) as t group by user_id limit 3;", mv);
// TODO: support order by elimiation
testRewriteFail(mv, query);
testRewriteOK(mv, query);
}

@Test
Expand All @@ -331,12 +333,33 @@ public void testTextMatchRewriteWithSubQuery4() {
testRewriteOK(mv, query);
}

/**
 * The mv definition is an outer select with a filter over an aggregate sub-query that has an
 * order by clause. The query embeds the mv text inside a CTE and aggregates over that CTE with
 * a limit; the pair is checked with {@code testRewriteOK}.
 */
@Test
public void testTextMatchRewriteWithSubQuery5() {
    final String mvDefinition = "select * from (" +
            " select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " +
            " group by user_id, time order by user_id, time) s where user_id != 'xxxx'";
    // the mv text is substituted into the CTE body
    final String queryTemplate = "with cte1 as (select * from (%s) as t) select user_id, count(time) " +
            " from cte1 as t group by user_id limit 3;";
    testRewriteOK(mvDefinition, String.format(queryTemplate, mvDefinition));
}

/**
 * Checks with {@code testRewriteOK} a query that wraps the mv text in a CTE and aggregates over
 * it; the mv definition is a filtered select over an ordered aggregate sub-query.
 *
 * <p>NOTE(review): the mv definition and the query here are the same as in
 * {@code testTextMatchRewriteWithSubQuery5}; confirm whether a different case was intended.
 */
@Test
public void testTextMatchRewriteWithSubQuery6() {
    String viewSql = "select * from (" +
            " select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " +
            " group by user_id, time order by user_id, time) s where user_id != 'xxxx'";
    String sql = String.format(
            "with cte1 as (select * from (%s) as t) select user_id, count(time) " +
            " from cte1 as t group by user_id limit 3;",
            viewSql);
    testRewriteOK(viewSql, sql);
}

/**
 * The mv definition is an aggregate query without an order by clause. The query selects from
 * the mv text as a sub-query and adds its own order by on top; the pair is checked with
 * {@code testRewriteOK}.
 */
@Test
public void testTextMatchRewriteWithExtraOrder1() {
    final String mvDefinition =
            "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time";
    // the extra order by lives only in the outer query, not in the mv text
    final String orderedQuery =
            String.format("select user_id from (%s) t order by user_id, time;", mvDefinition);
    testRewriteOK(mvDefinition, orderedQuery);
}

@Test
public void testTextMatchRewriteWithExtraOrder2() {
String mv = "select user_id, count(1) from (select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " +
Expand Down
Loading