Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(interactive): Support ScanEarlyStopRule for Query Optimization #4431

Merged
merged 14 commits into from
Jan 22, 2025
19 changes: 18 additions & 1 deletion docs/interactive_engine/gopt.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,4 +349,21 @@ Design of GOpt
:::

### Detailed Introduction
A comprehensive introduction to GOpt will be provided in subsequent sections. Please stay tuned for detailed updates and information.

#### Rules

Rules play a pivotal role in GOpt’s optimization framework, enabling efficient and effective query transformations. Below is an outline of some key rules implemented in GOpt:

**ScanEarlyStopRule**: Pushes the limit operation down to the scan node. During the scan process, the scan stops as soon as the specified limit count is reached.

**ScanExpandFusionRule**: This rule transforms edge expansion into edge scan wherever possible. For example, consider the following Cypher query:
```cypher
Match (a:PERSON)-[b:KNOWS]->(c:PERSON) Return b.name;
```
Although the query involves Scan and GetV steps, their results are not directly utilized by subsequent project operations. The only effectively used data is the edge data produced by the Expand operation. In such cases, we can perform a fusion operation, transforming the pattern
`(a:PERSON)-[b:KNOWS]->(c:PERSON)` into a scan operation on the KNOWS edge. It is important to note that whether fusion is feasible also depends on the label dependencies between nodes and edges. If the edge label is determined strictly by the triplet (src_label, edge_label, dst_label), fusion cannot be performed. For example, consider the following query:
```cypher
Match (a:PERSON)-[b:LIKES]->(c:COMMENT) Return b.name;
```

**TopKPushDownRule**: This rule pushes down topK operations to the project node and is based on Calcite's [SortProjectTransposeRule](https://calcite.apache.org/javadocAggregate/org/apache/calcite/rel/rules/SortProjectTransposeRule.html), leveraging the original rule wherever possible. However, in our more complex distributed scenario, deferring the execution of the project node can disrupt already sorted data. To address this, we modified the matching logic in `SortProjectTransposeRule`. Currently, the PushDown operation is applied only when the sort fields are empty, which means only the limit is pushed down to the project node.
41 changes: 3 additions & 38 deletions flex/engines/graph_db/runtime/common/operators/retrieve/scan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ namespace gs {
namespace runtime {

Context Scan::find_vertex_with_oid(const GraphReadInterface& graph,
label_t label, const Any& oid, int alias) {
label_t label, const Any& oid,
int32_t alias) {
SLVertexColumnBuilder builder(label);
vid_t vid;
if (graph.GetVertexIndex(label, oid, vid)) {
Expand All @@ -31,7 +32,7 @@ Context Scan::find_vertex_with_oid(const GraphReadInterface& graph,
}

Context Scan::find_vertex_with_gid(const GraphReadInterface& graph,
label_t label, int64_t gid, int alias) {
label_t label, int64_t gid, int32_t alias) {
SLVertexColumnBuilder builder(label);
if (GlobalId::get_label_id(gid) == label) {
builder.push_back_opt(GlobalId::get_vid(gid));
Expand All @@ -44,42 +45,6 @@ Context Scan::find_vertex_with_gid(const GraphReadInterface& graph,
return ctx;
}

Context Scan::find_vertex_with_id(const GraphReadInterface& graph,
label_t label, const Any& pk, int alias,
bool scan_oid) {
if (scan_oid) {
SLVertexColumnBuilder builder(label);
vid_t vid;
if (graph.GetVertexIndex(label, pk, vid)) {
builder.push_back_opt(vid);
}
Context ctx;
ctx.set(alias, builder.finish());
return ctx;
} else {
SLVertexColumnBuilder builder(label);
vid_t vid{};
int64_t gid{};
if (pk.type == PropertyType::kInt64) {
gid = pk.AsInt64();
} else if (pk.type == PropertyType::kInt32) {
gid = pk.AsInt32();
} else {
LOG(FATAL) << "Unsupported primary key type";
}
if (GlobalId::get_label_id(gid) == label) {
vid = GlobalId::get_vid(gid);
} else {
LOG(ERROR) << "Global id " << gid << " does not match label " << label;
return Context();
}
builder.push_back_opt(vid);
Context ctx;
ctx.set(alias, builder.finish());
return ctx;
}
}

template <typename T>
static Context _scan_vertex_with_special_vertex_predicate(
const GraphReadInterface& graph, const ScanParams& params,
Expand Down
80 changes: 74 additions & 6 deletions flex/engines/graph_db/runtime/common/operators/retrieve/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
struct ScanParams {
int alias;
std::vector<label_t> tables;
int32_t limit;

ScanParams() : alias(-1), limit(std::numeric_limits<int32_t>::max()) {}
};
class Scan {
public:
Expand Down Expand Up @@ -61,6 +64,50 @@
return ctx;
}

template <typename PRED_T>
static Context scan_vertex_with_limit(const GraphReadInterface& graph,
const ScanParams& params,
const PRED_T& predicate) {
Context ctx;
int32_t cur_limit = params.limit;
if (params.tables.size() == 1) {
label_t label = params.tables[0];
SLVertexColumnBuilder builder(label);
auto vertices = graph.GetVertexSet(label);
for (auto vid : vertices) {
if (cur_limit <= 0) {
break;
}
if (predicate(label, vid)) {
builder.push_back_opt(vid);
cur_limit--;
}
}
ctx.set(params.alias, builder.finish());
} else if (params.tables.size() > 1) {
MSVertexColumnBuilder builder;

for (auto label : params.tables) {
if (cur_limit <= 0) {
break;
}
auto vertices = graph.GetVertexSet(label);
builder.start_label(label);
for (auto vid : vertices) {
if (cur_limit <= 0) {
break;
}
if (predicate(label, vid)) {
builder.push_back_opt(vid);
cur_limit--;
}
}
}
ctx.set(params.alias, builder.finish());
}
return ctx;
}

static Context scan_vertex_with_special_vertex_predicate(
const GraphReadInterface& graph, const ScanParams& params,
const SPVertexPredicate& pred);
Expand All @@ -70,65 +117,89 @@
const ScanParams& params, const PRED_T& predicate,
const std::vector<int64_t>& gids) {
Context ctx;
int32_t cur_limit = params.limit;
if (params.tables.size() == 1) {
label_t label = params.tables[0];
SLVertexColumnBuilder builder(label);
for (auto gid : gids) {
if (cur_limit <= 0) {
break;
}
vid_t vid = GlobalId::get_vid(gid);
if (GlobalId::get_label_id(gid) == label && predicate(label, vid)) {
builder.push_back_opt(vid);
cur_limit--;
}
}
ctx.set(params.alias, builder.finish());
} else if (params.tables.size() > 1) {
MLVertexColumnBuilder builder;

for (auto label : params.tables) {
if (cur_limit <= 0) {
break;
}
for (auto gid : gids) {
if (cur_limit <= 0) {
break;
}
vid_t vid = GlobalId::get_vid(gid);
if (GlobalId::get_label_id(gid) == label && predicate(label, vid)) {
builder.push_back_vertex({label, vid});
cur_limit--;
}
}
}
ctx.set(params.alias, builder.finish());
}
return ctx;
}

static Context filter_gids_with_special_vertex_predicate(
const GraphReadInterface& graph, const ScanParams& params,
const SPVertexPredicate& predicate, const std::vector<int64_t>& oids);

template <typename PRED_T>
static Context filter_oids(const GraphReadInterface& graph,
const ScanParams& params, const PRED_T& predicate,
const std::vector<Any>& oids) {
Context ctx;
auto limit = params.limit;
if (params.tables.size() == 1) {
label_t label = params.tables[0];
SLVertexColumnBuilder builder(label);
for (auto oid : oids) {
if (limit <= 0) {
break;
}
vid_t vid;
if (graph.GetVertexIndex(label, oid, vid)) {
if (predicate(label, vid)) {
builder.push_back_opt(vid);
--limit;
}
}
}
ctx.set(params.alias, builder.finish());
} else if (params.tables.size() > 1) {
std::vector<std::pair<label_t, vid_t>> vids;

for (auto label : params.tables) {
if (limit <= 0) {
break;
}
for (auto oid : oids) {
if (limit <= 0) {
break;
}
vid_t vid;
if (graph.GetVertexIndex(label, oid, vid)) {
if (predicate(label, vid)) {
vids.emplace_back(label, vid);
--limit;
}
}
}

Check notice on line 202 in flex/engines/graph_db/runtime/common/operators/retrieve/scan.h

View check run for this annotation

codefactor.io / CodeFactor

flex/engines/graph_db/runtime/common/operators/retrieve/scan.h#L131-L202

Complex Method
}
if (vids.size() == 1) {
SLVertexColumnBuilder builder(vids[0].first);
Expand All @@ -149,15 +220,12 @@
const GraphReadInterface& graph, const ScanParams& params,
const SPVertexPredicate& predicate, const std::vector<Any>& oids);

static Context find_vertex_with_id(const GraphReadInterface& graph,
label_t label, const Any& pk, int alias,
bool scan_oid);

static Context find_vertex_with_oid(const GraphReadInterface& graph,
label_t label, const Any& pk, int alias);
label_t label, const Any& pk,
int32_t alias);

static Context find_vertex_with_gid(const GraphReadInterface& graph,
label_t label, int64_t pk, int alias);
label_t label, int64_t pk, int32_t alias);
};

} // namespace runtime
Expand Down
54 changes: 42 additions & 12 deletions flex/engines/graph_db/runtime/execute/ops/retrieve/scan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,17 +495,31 @@
auto expr =
parse_expression(graph, tmp, params, pred_, VarType::kVertexVar);
if (expr->is_optional()) {
auto ret = Scan::scan_vertex(
graph, scan_params_, [&expr](label_t label, vid_t vid) {
return expr->eval_vertex(label, vid, 0, 0).as_bool();
});
return ret;
if (scan_params_.limit == std::numeric_limits<int32_t>::max()) {
return Scan::scan_vertex(
graph, scan_params_, [&expr](label_t label, vid_t vid) {
return expr->eval_vertex(label, vid, 0, 0).as_bool();
});
} else {
return Scan::scan_vertex_with_limit(
graph, scan_params_, [&expr](label_t label, vid_t vid) {
return expr->eval_vertex(label, vid, 0, 0).as_bool();
});
}
} else {
auto ret = Scan::scan_vertex(
graph, scan_params_, [&expr](label_t label, vid_t vid) {
return expr->eval_vertex(label, vid, 0).as_bool();
});
return ret;
if (scan_params_.limit == std::numeric_limits<int32_t>::max()) {
auto ret = Scan::scan_vertex(
graph, scan_params_, [&expr](label_t label, vid_t vid) {
return expr->eval_vertex(label, vid, 0).as_bool();
});
return ret;
} else {
auto ret = Scan::scan_vertex_with_limit(
graph, scan_params_, [&expr](label_t label, vid_t vid) {
return expr->eval_vertex(label, vid, 0).as_bool();
});
return ret;
}
}
}

Expand All @@ -523,8 +537,13 @@
const std::map<std::string, std::string>& params,
gs::runtime::Context&& ctx,
gs::runtime::OprTimer& timer) override {
return Scan::scan_vertex(graph, scan_params_,
[](label_t, vid_t) { return true; });
if (scan_params_.limit == std::numeric_limits<int32_t>::max()) {
return Scan::scan_vertex(graph, scan_params_,
[](label_t, vid_t) { return true; });
} else {
return Scan::scan_vertex_with_limit(graph, scan_params_,
[](label_t, vid_t) { return true; });
}
}

private:
Expand All @@ -532,152 +551,163 @@
};

auto parse_ids_with_type(PropertyType type,
const algebra::IndexPredicate& triplet) {
std::function<std::vector<Any>(ParamsType)> ids;
switch (type.type_enum) {
case impl::PropertyTypeImpl::kInt64: {
parse_ids_from_idx_predicate<int64_t>(triplet, ids);
} break;
case impl::PropertyTypeImpl::kInt32: {
parse_ids_from_idx_predicate<int32_t>(triplet, ids);
} break;
case impl::PropertyTypeImpl::kStringView: {
parse_ids_from_idx_predicate(triplet, ids);
} break;
default:
LOG(FATAL) << "unsupported type" << static_cast<int>(type.type_enum);
break;
}
return ids;
}

std::pair<std::unique_ptr<IReadOperator>, ContextMeta> ScanOprBuilder::Build(
const gs::Schema& schema, const ContextMeta& ctx_meta,
const physical::PhysicalPlan& plan, int op_idx) {
ContextMeta ret_meta;
int alias = -1;
if (plan.plan(op_idx).opr().scan().has_alias()) {
alias = plan.plan(op_idx).opr().scan().alias().value();
}
ret_meta.set(alias);
auto scan_opr = plan.plan(op_idx).opr().scan();
CHECK(scan_opr.scan_opt() == physical::Scan::VERTEX);
CHECK(scan_opr.has_params());

ScanParams scan_params;
scan_params.alias = scan_opr.has_alias() ? scan_opr.alias().value() : -1;
scan_params.limit = std::numeric_limits<int32_t>::max();
if (scan_opr.params().has_limit()) {
auto& limit_range = scan_opr.params().limit();
if (limit_range.lower() != 0) {
LOG(FATAL) << "Scan with lower limit expect 0, but got "
<< limit_range.lower();
}
if (limit_range.upper() > 0) {
scan_params.limit = limit_range.upper();
}
}
for (auto& table : scan_opr.params().tables()) {
// bug here, exclude invalid vertex label id
if (schema.vertex_label_num() <= table.id()) {
continue;
}
scan_params.tables.emplace_back(table.id());
}
if (scan_opr.has_idx_predicate()) {
bool scan_oid = false;
CHECK(check_idx_predicate(scan_opr, scan_oid));
// only one label and without predicate
if (scan_params.tables.size() == 1 && scan_oid &&
(!scan_opr.params().has_predicate())) {
const auto& pks = schema.get_vertex_primary_key(scan_params.tables[0]);
const auto& [type, _, __] = pks[0];
auto oids = parse_ids_with_type(type, scan_opr.idx_predicate());
return std::make_pair(
std::make_unique<FilterOidsWithoutPredOpr>(scan_params, oids),
ret_meta);
}

// without predicate
if (!scan_opr.params().has_predicate()) {
if (!scan_oid) {
auto gids =
parse_ids_with_type(PropertyType::kInt64, scan_opr.idx_predicate());
return std::make_pair(
std::make_unique<FilterGidsWithoutPredOpr>(scan_params, gids),
ret_meta);
} else {
std::vector<std::function<std::vector<Any>(ParamsType)>> oids;
std::set<int> types;
for (auto& table : scan_params.tables) {
const auto& pks = schema.get_vertex_primary_key(table);
const auto& [type, _, __] = pks[0];
int type_impl = static_cast<int>(type.type_enum);
if (types.find(type_impl) == types.end()) {
types.insert(type_impl);
const auto& oid =
parse_ids_with_type(type, scan_opr.idx_predicate());
oids.emplace_back(oid);
}
}
if (types.size() == 1) {
return std::make_pair(
std::make_unique<FilterOidsWithoutPredOpr>(scan_params, oids[0]),
ret_meta);
} else {
return std::make_pair(
std::make_unique<FilterMultiTypeOidsWithoutPredOpr>(scan_params,
oids),
ret_meta);
}
}
} else {
auto sp_vertex_pred =
parse_special_vertex_predicate(scan_opr.params().predicate());
if (scan_oid) {
std::set<int> types;
std::vector<std::function<std::vector<Any>(ParamsType)>> oids;
for (auto& table : scan_params.tables) {
const auto& pks = schema.get_vertex_primary_key(table);
const auto& [type, _, __] = pks[0];
auto type_impl = static_cast<int>(type.type_enum);
if (types.find(type_impl) == types.end()) {
auto oid = parse_ids_with_type(type, scan_opr.idx_predicate());
types.insert(type_impl);
oids.emplace_back(oid);
}
}
if (types.size() == 1) {
if (sp_vertex_pred.has_value()) {
return std::make_pair(std::make_unique<FilterOidsSPredOpr>(
scan_params, oids[0], *sp_vertex_pred),
ret_meta);
} else {
return std::make_pair(
std::make_unique<FilterOidsGPredOpr>(
scan_params, oids[0], scan_opr.params().predicate()),
ret_meta);
}
} else {
if (sp_vertex_pred.has_value()) {
return std::make_pair(std::make_unique<FilterOidsMultiTypeSPredOpr>(
scan_params, oids, *sp_vertex_pred),
ret_meta);
} else {
return std::make_pair(
std::make_unique<FilterOidsMultiTypeGPredOpr>(
scan_params, oids, scan_opr.params().predicate()),
ret_meta);
}
}

} else {
auto gids =
parse_ids_with_type(PropertyType::kInt64, scan_opr.idx_predicate());
if (sp_vertex_pred.has_value()) {
return std::make_pair(std::make_unique<FilterGidsSPredOpr>(
scan_params, gids, *sp_vertex_pred),
ret_meta);
} else {
return std::make_pair(
std::make_unique<FilterGidsGPredOpr>(
scan_params, gids, scan_opr.params().predicate()),
ret_meta);
}
}
}

} else {
if (scan_opr.params().has_predicate()) {

Check notice on line 710 in flex/engines/graph_db/runtime/execute/ops/retrieve/scan.cc

View check run for this annotation

codefactor.io / CodeFactor

flex/engines/graph_db/runtime/execute/ops/retrieve/scan.cc#L554-L710

Complex Method
auto sp_vertex_pred =
parse_special_vertex_predicate(scan_opr.params().predicate());
if (sp_vertex_pred.has_value()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,40 @@ def test_call_proc_in_cypher(interactive_session, neo4j_session, create_modern_g
assert cnt == 8


@pytest.mark.skipif(
os.environ.get("RUN_ON_PROTO", None) != "ON",
reason="Scan+Limit fuse only works on proto",
)
def test_scan_limit_fuse(interactive_session, neo4j_session, create_modern_graph):
print("[Test call procedure in cypher]")
import_data_to_full_modern_graph(interactive_session, create_modern_graph)
start_service_on_graph(interactive_session, create_modern_graph)
ensure_compiler_schema_ready(
interactive_session, neo4j_session, create_modern_graph
)
result = neo4j_session.run(
'MATCH(p: person) with p.id as oid CALL k_neighbors("person", oid, 1) return label_name, vertex_oid;'
)
cnt = 0
for record in result:
cnt += 1
assert cnt == 8

# Q: Why we could use this query to verify whether Scan+Limit fuse works?
# A: If Scan+Limit fuse works, the result of this query should be 2, otherwise it should be 6
result = neo4j_session.run("MATCH(n) return n.id limit 2")
cnt = 0
for record in result:
cnt += 1
assert cnt == 2

result = neo4j_session.run("MATCH(n) return n.id limit 0")
cnt = 0
for record in result:
cnt += 1
assert cnt == 0


def test_custom_pk_name(
interactive_session, neo4j_session, create_graph_with_custom_pk_name
):
Expand Down
3 changes: 3 additions & 0 deletions flex/tests/hqps/interactive_config_test_cbo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ compiler:
- NotMatchToAntiJoinRule
- ExtendIntersectRule
- ExpandGetVFusionRule
- ScanExpandFusionRule
- TopKPushDownRule
- ScanEarlyStopRule # This rule must be placed after TopKPushDownRule and ScanExpandFusionRule
meta:
reader:
schema:
Expand Down
Loading
Loading