Skip to content

Commit

Permalink
[Feature] support array_agg in window function StarRocks#40881
Browse files Browse the repository at this point in the history
  • Loading branch information
mygrsun2 committed May 5, 2024
1 parent bcb4c18 commit 3eb226c
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 5 deletions.
4 changes: 2 additions & 2 deletions be/src/exec/analytor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ Status Analytor::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile*
if (fn.name.function_name == "count" || fn.name.function_name == "row_number" ||
fn.name.function_name == "rank" || fn.name.function_name == "dense_rank" ||
fn.name.function_name == "cume_dist" || fn.name.function_name == "percent_rank" ||
fn.name.function_name == "ntile") {
fn.name.function_name == "ntile" || fn.name.function_name == "agg") {
auto return_type = TYPE_BIGINT;
if (fn.name.function_name == "cume_dist" || fn.name.function_name == "percent_rank") {
if (fn.name.function_name == "cume_dist" || fn.name.function_name == "percent_rank" || fn.name.function_name == "agg") {
return_type = TYPE_DOUBLE;
}
is_input_nullable = !fn.arg_types.empty() && (desc.nodes[0].has_nullable_child || has_outer_join_child);
Expand Down
2 changes: 1 addition & 1 deletion be/src/exec/analytor.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ class Analytor final : public pipeline::ContextWithDependency {
// it's necessary to specify the size of the partition.
void _set_partition_size_for_function();
// Reports whether `function_name` is one of the window functions for which the
// partition size has to be specified before evaluation: "cume_dist",
// "percent_rank" or "agg". The comparison is exact and case-sensitive.
//
// A single return statement is used on purpose: a leftover earlier `return`
// listing only the first two names made the "agg" check unreachable.
bool _require_partition_size(const std::string& function_name) {
    return function_name == "cume_dist" || function_name == "percent_rank" || function_name == "agg";
}

RuntimeState* _state = nullptr;
Expand Down
4 changes: 4 additions & 0 deletions be/src/exprs/agg/factory/aggregate_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ AggregateFunctionPtr AggregateFactory::MakePercentRankWindowFunction() {
return std::make_shared<PercentRankWindowFunction>();
}

// Creates a new shared AggWindowFunction instance and hands it back through the
// generic AggregateFunctionPtr handle.
AggregateFunctionPtr AggregateFactory::MakeAggWindowFunction() {
    AggregateFunctionPtr agg_fn = std::make_shared<AggWindowFunction>();
    return agg_fn;
}

// Creates a new shared NtileWindowFunction instance and hands it back through
// the generic AggregateFunctionPtr handle.
AggregateFunctionPtr AggregateFactory::MakeNtileWindowFunction() {
    AggregateFunctionPtr ntile_fn = std::make_shared<NtileWindowFunction>();
    return ntile_fn;
}
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/agg/factory/aggregate_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ class AggregateFactory {

static AggregateFunctionPtr MakePercentRankWindowFunction();

static AggregateFunctionPtr MakeAggWindowFunction();

static AggregateFunctionPtr MakeNtileWindowFunction();

template <LogicalType LT, bool ignoreNulls>
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/agg/factory/aggregate_resolver_window.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ void AggregateFuncResolver::register_window() {
AggregateFactory::MakeCumeDistWindowFunction());
add_aggregate_mapping_notnull<TYPE_BIGINT, TYPE_DOUBLE>("percent_rank", true,
AggregateFactory::MakePercentRankWindowFunction());
add_aggregate_mapping_notnull<TYPE_BIGINT, TYPE_DOUBLE>("agg", true,
AggregateFactory::MakeAggWindowFunction());
add_aggregate_mapping_notnull<TYPE_BIGINT, TYPE_BIGINT>("row_number", true,
AggregateFactory::MakeRowNumberWindowFunction());
add_aggregate_mapping_notnull<TYPE_BIGINT, TYPE_BIGINT>("ntile", true, AggregateFactory::MakeNtileWindowFunction());
Expand Down
45 changes: 45 additions & 0 deletions be/src/exprs/agg/window.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ class CumeDistWindowFunction final : public WindowFunction<CumeDistState> {
size_t end) const override {
DCHECK_GT(end, start);
auto& s = this->data(state);
LOG(INFO) << "erictest" << 9999 << " ttt";
auto* column = down_cast<DoubleColumn*>(dst);
for (size_t i = start; i < end; ++i) {
column->get_data()[i] = (double)s.rank / s.count;
Expand Down Expand Up @@ -331,6 +332,50 @@ class PercentRankWindowFunction final : public WindowFunction<PercentRankState>
std::string get_name() const override { return "percent_rank"; }
};


// Per-partition state for AggWindowFunction.
//
// The default member initializers are the same values AggWindowFunction::reset()
// assigns, so a freshly constructed state is never read with indeterminate
// contents even if reset() has not run yet.
struct AggState {
    // Running total of the sizes of the peer groups seen so far.
    int64_t rank = 0;
    // Start offset of the last peer group that was accounted for; -1 means none yet.
    int64_t peer_group_start = -1;
    // Divisor used when producing the output value. Starts at 1.
    // NOTE(review): presumably overwritten with the partition size by the caller - confirm.
    int64_t count = 1;
};

// Window function registered under the name "agg".
//
// It accumulates the sizes of the peer groups it is shown into `rank` and
// outputs `rank / count` as a double for every requested row.
class AggWindowFunction final : public WindowFunction<AggState> {
    // Puts the state back to its initial values: nothing accumulated, no peer
    // group seen yet, divisor of 1.
    void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state) const override {
        auto& s = this->data(state);
        s.rank = 0;
        s.peer_group_start = -1;
        s.count = 1;
    }

    // Shifts the remembered peer-group start back by `count` positions.
    // NOTE(review): presumably called after `count` leading rows were dropped
    // from the buffered data, so stored offsets stay aligned - confirm.
    void reset_state_for_contraction(FunctionContext* ctx, AggDataPtr __restrict state, size_t count) const override {
        this->data(state).peer_group_start -= count;
    }

    // Accounts for a peer group exactly once: when `peer_group_start` differs
    // from the one already recorded, the new start is stored and the size of
    // the group (`peer_group_end - peer_group_start`) is added to `rank`.
    // The frame bounds and input columns are not used.
    void update_batch_single_state_with_frame(FunctionContext* ctx, AggDataPtr __restrict state, const Column** columns,
                                              int64_t peer_group_start, int64_t peer_group_end, int64_t frame_start,
                                              int64_t frame_end) const override {
        auto& s = this->data(state);
        if (s.peer_group_start == peer_group_start) {
            // Same peer group as last time: already counted.
            return;
        }
        s.peer_group_start = peer_group_start;
        s.rank += peer_group_end - peer_group_start;
    }

    // Writes `rank / count` (as double) into rows [start, end) of `dst`, which
    // is down-cast to a DoubleColumn. Requires end > start.
    void get_values(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* dst, size_t start,
                    size_t end) const override {
        DCHECK_GT(end, start);
        const auto& s = this->data(state);
        auto* column = down_cast<DoubleColumn*>(dst);
        // The value is identical for every row in the range, so compute it once.
        const double value = static_cast<double>(s.rank) / s.count;
        for (size_t i = start; i < end; ++i) {
            column->get_data()[i] = value;
        }
    }

    std::string get_name() const override { return "agg"; }
};

// The NTILE window function divides ordered rows in the partition into `num_buckets` ranked groups
// of as equal size as possible and returns the group id of each row starting from 1.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ public class AnalyticExpr extends Expr {
public static String ROWNUMBER = "ROW_NUMBER";
public static String CUMEDIST = "CUME_DIST";
public static String PERCENTRANK = "PERCENT_RANK";
public static String AGG = "AGG";
public static String NTILE = "NTILE";
public static String MIN = "MIN";
public static String MAX = "MAX";
Expand Down Expand Up @@ -277,6 +278,14 @@ public static boolean isCumeFn(Function fn) {
return fn.functionName().equalsIgnoreCase(CUMEDIST) || fn.functionName().equalsIgnoreCase(PERCENTRANK);
}

/**
 * Tells whether {@code fn} is the AGG analytic function.
 *
 * @param fn function to inspect
 * @return true only if {@code fn} passes {@code isAnalyticFn} and its name
 *         equals {@code AGG}, ignoring case
 */
public static boolean isAggFn(Function fn) {
    return isAnalyticFn(fn) && fn.functionName().equalsIgnoreCase(AGG);
}

public static boolean isRowNumberFn(Function fn) {
if (!isAnalyticFn(fn)) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ public class FunctionSet {
public static final String RANK = "rank";
public static final String CUME_DIST = "cume_dist";
public static final String PERCENT_RANK = "percent_rank";
public static final String AGG = "agg";
public static final String NTILE = "ntile";
public static final String ROW_NUMBER = "row_number";
public static final String SESSION_NUMBER = "session_number";
Expand Down Expand Up @@ -646,6 +647,7 @@ public class FunctionSet {
.add(FunctionSet.DENSE_RANK)
.add(FunctionSet.RANK)
.add(FunctionSet.CUME_DIST)
.add(FunctionSet.AGG)
.add(FunctionSet.PERCENT_RANK)
.add(FunctionSet.NTILE)
.add(FunctionSet.ROW_NUMBER)
Expand Down Expand Up @@ -1149,6 +1151,9 @@ private void initAggregateBuiltins() {
// Percent rank
addBuiltin(AggregateFunction.createAnalyticBuiltin(PERCENT_RANK,
Collections.emptyList(), Type.DOUBLE, Type.VARBINARY));
//Agg
addBuiltin(AggregateFunction.createAnalyticBuiltin(AGG,
Collections.emptyList(), Type.DOUBLE, Type.VARBINARY));
// Ntile
addBuiltin(AggregateFunction.createAnalyticBuiltin(NTILE,
Lists.newArrayList(Type.BIGINT), Type.BIGINT, Type.BIGINT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ public static void verifyAnalyticExpression(AnalyticExpr analyticExpr) {

if (analyticExpr.getWindow() != null) {
if ((isRankingFn(analyticFunction.getFn()) || isCumeFn(analyticFunction.getFn()) ||
isOffsetFn(analyticFunction.getFn()) || isHllAggFn(analyticFunction.getFn()))) {
isOffsetFn(analyticFunction.getFn()) || isHllAggFn(analyticFunction.getFn()) ||
isAggFn(analyticFunction.getFn()))) {
throw new SemanticException("Windowing clause not allowed with '" + analyticFunction.toSql() + "'",
analyticExpr.getPos());
}
Expand Down Expand Up @@ -346,6 +347,14 @@ private static boolean isCumeFn(Function fn) {
|| fn.functionName().equalsIgnoreCase(AnalyticExpr.PERCENTRANK);
}

/**
 * Tells whether {@code fn} is the AGG analytic function.
 *
 * @param fn function to inspect
 * @return true only if {@code fn} passes {@code isAnalyticFn} and its name
 *         equals {@code AnalyticExpr.AGG}, ignoring case
 */
private static boolean isAggFn(Function fn) {
    return isAnalyticFn(fn) && fn.functionName().equalsIgnoreCase(AnalyticExpr.AGG);
}

private static boolean isNtileFn(Function fn) {
if (!isAnalyticFn(fn)) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ public static WindowOperator standardize(AnalyticExpr analyticExpr) {
Preconditions.checkState(windowFrame == null, "Unexpected window set for "
+ callExpr.getFn().getFunctionName() + "()");
windowFrame = AnalyticWindow.DEFAULT_WINDOW;
} else if (AnalyticExpr.isAggFn(callExpr.getFn())) {
Preconditions.checkState(windowFrame == null, "Unexpected window set for "
+ callExpr.getFn().getFunctionName() + "()");
windowFrame = AnalyticWindow.DEFAULT_WINDOW;
} else if (AnalyticExpr.isOffsetFn(callExpr.getFn())) {
try {
Preconditions.checkState(windowFrame == null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2289,6 +2289,7 @@ windowFunction
| name = RANK '(' ')'
| name = DENSE_RANK '(' ')'
| name = CUME_DIST '(' ')'
| name = AGG '(' ')'
| name = PERCENT_RANK '(' ')'
| name = NTILE '(' expression? ')'
| name = LEAD '(' (expression ignoreNulls? (',' expression)*)? ')' ignoreNulls?
Expand Down Expand Up @@ -2632,7 +2633,7 @@ nonReserved
| BACKEND | BACKENDS | BACKUP | BEGIN | BITMAP_UNION | BLACKLIST | BLACKHOLE | BINARY | BODY | BOOLEAN | BROKER | BUCKETS
| BUILTIN | BASE | BEFORE
| CACHE | CAST | CANCEL | CATALOG | CATALOGS | CEIL | CHAIN | CHARSET | CLEAN | CLEAR | CLUSTER | CLUSTERS | CURRENT | COLLATION | COLUMNS
| CUME_DIST | CUMULATIVE | COMMENT | COMMIT | COMMITTED | COMPUTE | CONNECTION | CONSISTENT | COSTS | COUNT
| CUME_DIST | AGG | CUMULATIVE | COMMENT | COMMIT | COMMITTED | COMPUTE | CONNECTION | CONSISTENT | COSTS | COUNT
| CONFIG | COMPACT
| DATA | DATE | DATACACHE | DATETIME | DAY | DECOMMISSION | DISABLE | DISK | DISTRIBUTION | DUPLICATE | DYNAMIC | DISTRIBUTED | DICTIONARY | DICTIONARY_GET | DEALLOCATE
| ENABLE | END | ENGINE | ENGINES | ERRORS | EVENTS | EXECUTE | EXTERNAL | EXTRACT | EVERY | ENCLOSE | ESCAPE | EXPORT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ CREATE: 'CREATE';
CROSS: 'CROSS';
CUBE: 'CUBE';
CUME_DIST: 'CUME_DIST';
AGG: 'AGG';
CUMULATIVE: 'CUMULATIVE';
CURRENT: 'CURRENT';
CURRENT_DATE: 'CURRENT_DATE';
Expand Down
1 change: 1 addition & 0 deletions gensrc/thrift/PlanNodes.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,7 @@ enum TAggregationOp {
ANY_VALUE,
NTILE,
CUME_DIST,
AGG,
PERCENT_RANK
}

Expand Down

0 comments on commit 3eb226c

Please sign in to comment.