From 8466df70d632331d77b9cb6fb4c595c2bbfef3cf Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Mon, 2 Dec 2024 12:48:02 +0800 Subject: [PATCH] fix(query): keep remaining_predicates when filtering grouping sets (#16971) * fix(query): keep remaining_predicates when filtering grouping sets * update * update * update * update * Update 03_0003_select_group_by.test * update --------- Co-authored-by: TCeason <33082201+TCeason@users.noreply.github.com> --- src/query/ast/src/ast/format/syntax/query.rs | 9 + src/query/ast/src/ast/query.rs | 86 ++- src/query/ast/src/parser/query.rs | 29 +- src/query/ast/tests/it/parser.rs | 4 + src/query/ast/tests/it/testdata/stmt.txt | 543 ++++++++++++++++++ src/query/expression/benches/bench.rs | 26 + src/query/functions/src/scalars/arithmetic.rs | 43 +- src/query/functions/src/scalars/other.rs | 10 +- src/query/sql/src/planner/binder/aggregate.rs | 95 ++- .../rule_push_down_filter_aggregate.rs | 7 +- src/tests/sqlsmith/src/sql_gen/query.rs | 11 +- .../03_common/03_0003_select_group_by.test | 92 ++- 12 files changed, 863 insertions(+), 92 deletions(-) diff --git a/src/query/ast/src/ast/format/syntax/query.rs b/src/query/ast/src/ast/format/syntax/query.rs index deb6fafe30da2..0e93bfc94acb9 100644 --- a/src/query/ast/src/ast/format/syntax/query.rs +++ b/src/query/ast/src/ast/format/syntax/query.rs @@ -272,6 +272,15 @@ fn pretty_group_by(group_by: Option) -> RcDoc<'static> { ) .append(RcDoc::line()) .append(RcDoc::text(")")), + + GroupBy::Combined(sets) => RcDoc::line() + .append(RcDoc::text("GROUP BY ").append(RcDoc::line().nest(NEST_FACTOR))) + .append( + interweave_comma(sets.into_iter().map(|s| RcDoc::text(s.to_string()))) + .nest(NEST_FACTOR) + .group(), + ) + .append(RcDoc::line()), } } else { RcDoc::nil() diff --git a/src/query/ast/src/ast/query.rs b/src/query/ast/src/ast/query.rs index 76893ba296234..165b66fd6b220 100644 --- a/src/query/ast/src/ast/query.rs +++ b/src/query/ast/src/ast/query.rs @@ -189,38 +189,8 @@ impl Display for SelectStmt { // GROUP BY clause if self.group_by.is_some() { write!(f, " GROUP BY ")?; - match self.group_by.as_ref().unwrap() { - GroupBy::Normal(exprs) => { - write_comma_separated_list(f, exprs)?; - } - GroupBy::All => { - write!(f, "ALL")?; - } - GroupBy::GroupingSets(sets) => { - write!(f, "GROUPING SETS (")?; - for (i, set) in sets.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "(")?; - write_comma_separated_list(f, set)?; - write!(f, ")")?; - } - write!(f, ")")?; - } - GroupBy::Cube(exprs) => { - write!(f, "CUBE (")?; - write_comma_separated_list(f, exprs)?; - write!(f, ")")?; - } - GroupBy::Rollup(exprs) => { - write!(f, "ROLLUP (")?; - write_comma_separated_list(f, exprs)?; - write!(f, ")")?; - } - } + write!(f, "{}", self.group_by.as_ref().unwrap())?; } - // HAVING clause if let Some(having) = &self.having { write!(f, " HAVING {having}")?; @@ -254,6 +224,60 @@ pub enum GroupBy { Cube(Vec), /// GROUP BY ROLLUP ( expr [, expr]* ) Rollup(Vec), + Combined(Vec), +} + +impl GroupBy { + pub fn normal_items(&self) -> Vec { + match self { + GroupBy::Normal(exprs) => exprs.clone(), + _ => Vec::new(), + } + } +} + +impl Display for GroupBy { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + GroupBy::Normal(exprs) => { + write_comma_separated_list(f, exprs)?; + } + GroupBy::All => { + write!(f, "ALL")?; + } + GroupBy::GroupingSets(sets) => { + write!(f, "GROUPING SETS (")?; + for (i, set) in sets.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "(")?; + 
write_comma_separated_list(f, set)?; + write!(f, ")")?; + } + write!(f, ")")?; + } + GroupBy::Cube(exprs) => { + write!(f, "CUBE (")?; + write_comma_separated_list(f, exprs)?; + write!(f, ")")?; + } + GroupBy::Rollup(exprs) => { + write!(f, "ROLLUP (")?; + write_comma_separated_list(f, exprs)?; + write!(f, ")")?; + } + GroupBy::Combined(group_bys) => { + for (i, group_by) in group_bys.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", group_by)?; + } + } + } + Ok(()) + } } /// A relational set expression, like `SELECT ... FROM ... {UNION|EXCEPT|INTERSECT} SELECT ... FROM ...` diff --git a/src/query/ast/src/parser/query.rs b/src/query/ast/src/parser/query.rs index cd9cb44be96b9..efa318b117267 100644 --- a/src/query/ast/src/parser/query.rs +++ b/src/query/ast/src/parser/query.rs @@ -1073,10 +1073,6 @@ impl<'a, I: Iterator>> PrattParser } pub fn group_by_items(i: Input) -> IResult { - let normal = map(rule! { ^#comma_separated_list1(expr) }, |groups| { - GroupBy::Normal(groups) - }); - let all = map(rule! { ALL }, |_| GroupBy::All); let cube = map( @@ -1096,10 +1092,31 @@ pub fn group_by_items(i: Input) -> IResult { map(rule! { #expr }, |e| vec![e]), )); let group_sets = map( - rule! { GROUPING ~ SETS ~ "(" ~ ^#comma_separated_list1(group_set) ~ ")" }, + rule! { GROUPING ~ ^SETS ~ "(" ~ ^#comma_separated_list1(group_set) ~ ")" }, |(_, _, _, sets, _)| GroupBy::GroupingSets(sets), ); - rule!(#all | #group_sets | #cube | #rollup | #normal)(i) + + // New rule to handle multiple GroupBy items + let single_normal = map(rule! { #expr }, |group| GroupBy::Normal(vec![group])); + let group_by_item = alt((all, group_sets, cube, rollup, single_normal)); + map(rule! { ^#comma_separated_list1(group_by_item) }, |items| { + if items.len() > 1 { + if items.iter().all(|item| matches!(item, GroupBy::Normal(_))) { + let items = items + .into_iter() + .flat_map(|item| match item { + GroupBy::Normal(exprs) => exprs, + _ => unreachable!(), + }) + .collect(); + GroupBy::Normal(items) + } else { + GroupBy::Combined(items) + } + } else { + items.into_iter().next().unwrap() + } + })(i) } pub fn window_frame_bound(i: Input) -> IResult { diff --git a/src/query/ast/tests/it/parser.rs b/src/query/ast/tests/it/parser.rs index 97a63ef956b50..9b43de4881879 100644 --- a/src/query/ast/tests/it/parser.rs +++ b/src/query/ast/tests/it/parser.rs @@ -591,12 +591,16 @@ fn test_statement() { "#, r#"SHOW FILE FORMATS"#, r#"DROP FILE FORMAT my_csv"#, + r#"SELECT * FROM t GROUP BY all"#, + r#"SELECT * FROM t GROUP BY a, b, c, d"#, r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, c, d)"#, r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, (c, d))"#, r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (c), (d, e))"#, r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (), (d, e))"#, r#"SELECT * FROM t GROUP BY CUBE (a, b, c)"#, r#"SELECT * FROM t GROUP BY ROLLUP (a, b, c)"#, + r#"SELECT * FROM t GROUP BY a, ROLLUP (b, c)"#, + r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b)), a, ROLLUP (b, c)"#, r#"CREATE MASKING POLICY email_mask AS (val STRING) RETURNS STRING -> CASE WHEN current_role() IN ('ANALYST') THEN VAL ELSE '*********'END comment = 'this is a masking policy'"#, r#"CREATE OR REPLACE MASKING POLICY email_mask AS (val STRING) RETURNS STRING -> CASE WHEN current_role() IN ('ANALYST') THEN VAL ELSE '*********'END comment = 'this is a masking policy'"#, r#"DESC MASKING POLICY email_mask"#, diff --git a/src/query/ast/tests/it/testdata/stmt.txt b/src/query/ast/tests/it/testdata/stmt.txt index 
f5f9c63e1ba67..63fdfca7b5313 100644 --- a/src/query/ast/tests/it/testdata/stmt.txt +++ b/src/query/ast/tests/it/testdata/stmt.txt @@ -17455,6 +17455,227 @@ DropFileFormat { } +---------- Input ---------- +SELECT * FROM t GROUP BY all +---------- Output --------- +SELECT * FROM t GROUP BY ALL +---------- AST ------------ +Query( + Query { + span: Some( + 0..28, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..28, + ), + hints: None, + distinct: false, + top_n: None, + select_list: [ + StarColumns { + qualified: [ + Star( + Some( + 7..8, + ), + ), + ], + column_filter: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + span: Some( + 14..15, + ), + name: "t", + quote: None, + ident_type: None, + }, + alias: None, + temporal: None, + with_options: None, + pivot: None, + unpivot: None, + sample: None, + }, + ], + selection: None, + group_by: Some( + All, + ), + having: None, + window_list: None, + qualify: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + +---------- Input ---------- +SELECT * FROM t GROUP BY a, b, c, d +---------- Output --------- +SELECT * FROM t GROUP BY a, b, c, d +---------- AST ------------ +Query( + Query { + span: Some( + 0..35, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..35, + ), + hints: None, + distinct: false, + top_n: None, + select_list: [ + StarColumns { + qualified: [ + Star( + Some( + 7..8, + ), + ), + ], + column_filter: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + span: Some( + 14..15, + ), + name: "t", + quote: None, + ident_type: None, + }, + alias: None, + temporal: None, + with_options: None, + pivot: None, + unpivot: None, + sample: None, + }, + ], + selection: None, + group_by: Some( + Normal( + [ + ColumnRef { + span: Some( + 25..26, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 25..26, + ), + name: "a", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 28..29, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 28..29, + ), + name: "b", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 31..32, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 31..32, + ), + name: "c", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 34..35, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 34..35, + ), + name: "d", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + ), + having: None, + window_list: None, + qualify: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + ---------- Input ---------- SELECT * FROM t GROUP BY GROUPING SETS (a, b, c, d) ---------- Output --------- @@ -18361,6 +18582,328 @@ Query( ) +---------- Input ---------- +SELECT * FROM t GROUP BY a, ROLLUP (b, c) +---------- Output --------- +SELECT * FROM t GROUP BY a, ROLLUP (b, c) +---------- AST ------------ +Query( + Query { + span: Some( + 0..41, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..41, + ), + hints: None, + distinct: false, + top_n: None, + select_list: [ + StarColumns { + qualified: [ + Star( + Some( + 7..8, + ), + ), + ], + 
column_filter: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + span: Some( + 14..15, + ), + name: "t", + quote: None, + ident_type: None, + }, + alias: None, + temporal: None, + with_options: None, + pivot: None, + unpivot: None, + sample: None, + }, + ], + selection: None, + group_by: Some( + Combined( + [ + Normal( + [ + ColumnRef { + span: Some( + 25..26, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 25..26, + ), + name: "a", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + Rollup( + [ + ColumnRef { + span: Some( + 36..37, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 36..37, + ), + name: "b", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 39..40, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 39..40, + ), + name: "c", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + ], + ), + ), + having: None, + window_list: None, + qualify: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + +---------- Input ---------- +SELECT * FROM t GROUP BY GROUPING SETS ((a, b)), a, ROLLUP (b, c) +---------- Output --------- +SELECT * FROM t GROUP BY GROUPING SETS ((a, b)), a, ROLLUP (b, c) +---------- AST ------------ +Query( + Query { + span: Some( + 0..65, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..65, + ), + hints: None, + distinct: false, + top_n: None, + select_list: [ + StarColumns { + qualified: [ + Star( + Some( + 7..8, + ), + ), + ], + column_filter: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + span: Some( + 14..15, + ), + name: "t", + quote: None, + ident_type: None, + }, + alias: None, + temporal: None, + with_options: None, + pivot: None, + unpivot: None, + sample: None, + }, + ], + selection: None, + group_by: Some( + Combined( + [ + GroupingSets( + [ + [ + ColumnRef { + span: Some( + 41..42, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 41..42, + ), + name: "a", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 44..45, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 44..45, + ), + name: "b", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ], + ), + Normal( + [ + ColumnRef { + span: Some( + 49..50, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 49..50, + ), + name: "a", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + Rollup( + [ + ColumnRef { + span: Some( + 60..61, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 60..61, + ), + name: "b", + quote: None, + ident_type: None, + }, + ), + }, + }, + ColumnRef { + span: Some( + 63..64, + ), + column: ColumnRef { + database: None, + table: None, + column: Name( + Identifier { + span: Some( + 63..64, + ), + name: "c", + quote: None, + ident_type: None, + }, + ), + }, + }, + ], + ), + ], + ), + ), + having: None, + window_list: None, + qualify: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + ---------- Input ---------- CREATE MASKING 
POLICY email_mask AS (val STRING) RETURNS STRING -> CASE WHEN current_role() IN ('ANALYST') THEN VAL ELSE '*********'END comment = 'this is a masking policy' ---------- Output --------- diff --git a/src/query/expression/benches/bench.rs b/src/query/expression/benches/bench.rs index bfc342bc322aa..a92e7d30f4230 100644 --- a/src/query/expression/benches/bench.rs +++ b/src/query/expression/benches/bench.rs @@ -20,11 +20,14 @@ use criterion::Criterion; use databend_common_column::buffer::Buffer; use databend_common_expression::arrow::deserialize_column; use databend_common_expression::arrow::serialize_column; +use databend_common_expression::types::ArgType; use databend_common_expression::types::BinaryType; +use databend_common_expression::types::Int32Type; use databend_common_expression::types::StringType; use databend_common_expression::Column; use databend_common_expression::DataBlock; use databend_common_expression::FromData; +use databend_common_expression::Value; use rand::rngs::StdRng; use rand::Rng; use rand::SeedableRng; @@ -222,6 +225,29 @@ fn bench(c: &mut Criterion) { .collect::>(); }) }); + + let value1 = Value::::Column(left.clone()); + let value2 = Value::::Column(right.clone()); + + group.bench_function(format!("register_new/{length}"), |b| { + b.iter(|| { + let iter = (0..length).map(|i| { + let a = unsafe { value1.index_unchecked(i) }; + let b = unsafe { value2.index_unchecked(i) }; + a + b + }); + let _c = Int32Type::column_from_iter(iter, &[]); + }) + }); + + group.bench_function(format!("register_old/{length}"), |b| { + b.iter(|| { + let a = value1.as_column().unwrap(); + let b = value2.as_column().unwrap(); + let iter = a.iter().zip(b.iter()).map(|(a, b)| *a + b); + let _c = Int32Type::column_from_iter(iter, &[]); + }) + }); } } diff --git a/src/query/functions/src/scalars/arithmetic.rs b/src/query/functions/src/scalars/arithmetic.rs index bd42635b3e065..e3c60a6c3c84d 100644 --- a/src/query/functions/src/scalars/arithmetic.rs +++ b/src/query/functions/src/scalars/arithmetic.rs @@ -21,7 +21,6 @@ use std::str::FromStr; use std::sync::Arc; use databend_common_expression::serialize::read_decimal_with_size; -use databend_common_expression::types::binary::BinaryColumnBuilder; use databend_common_expression::types::decimal::DecimalDomain; use databend_common_expression::types::decimal::DecimalType; use databend_common_expression::types::nullable::NullableColumn; @@ -38,7 +37,6 @@ use databend_common_expression::types::NullableType; use databend_common_expression::types::NumberClass; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::SimpleDomain; -use databend_common_expression::types::StringColumn; use databend_common_expression::types::StringType; use databend_common_expression::types::ALL_FLOAT_TYPES; use databend_common_expression::types::ALL_INTEGER_TYPES; @@ -972,25 +970,24 @@ pub fn register_number_to_string(registry: &mut FunctionRegistry) { Value::Column(from) => { let options = NUM_TYPE::lexical_options(); const FORMAT: u128 = lexical_core::format::STANDARD; - type Native = ::Native; - let mut builder = StringColumnBuilder::with_capacity(from.len()); unsafe { + builder.row_buffer.resize( + ::Native::FORMATTED_SIZE_DECIMAL, + 0, + ); for x in from.iter() { - builder.row_buffer.resize( - ::Native::FORMATTED_SIZE_DECIMAL, - 0, - ); let len = lexical_core::write_with_options::<_, FORMAT>( Native::from(*x), - &mut builder.row_buffer, + &mut builder.row_buffer[0..], &options, ) .len(); - builder.row_buffer.truncate(len); - 
builder.commit_row(); + builder.data.push_value(std::str::from_utf8_unchecked( + &builder.row_buffer[0..len], + )); } } Value::Column(builder.build()) @@ -1005,29 +1002,27 @@ pub fn register_number_to_string(registry: &mut FunctionRegistry) { Value::Column(from) => { let options = NUM_TYPE::lexical_options(); const FORMAT: u128 = lexical_core::format::STANDARD; - let mut builder = - BinaryColumnBuilder::with_capacity(from.len(), from.len() + 1); - let values = &mut builder.data; - type Native = ::Native; - let mut offset: usize = 0; + let mut builder = StringColumnBuilder::with_capacity(from.len()); + unsafe { + builder.row_buffer.resize( + ::Native::FORMATTED_SIZE_DECIMAL, + 0, + ); for x in from.iter() { - values.reserve(offset + Native::FORMATTED_SIZE_DECIMAL); - values.set_len(offset + Native::FORMATTED_SIZE_DECIMAL); - let bytes = &mut values[offset..]; let len = lexical_core::write_with_options::<_, FORMAT>( Native::from(*x), - bytes, + &mut builder.row_buffer[0..], &options, ) .len(); - offset += len; - builder.offsets.push(offset as u64); + builder.data.push_value(std::str::from_utf8_unchecked( + &builder.row_buffer[0..len], + )); } - values.set_len(offset); } - let result = StringColumn::try_from(builder.build()).unwrap(); + let result = builder.build(); Value::Column(NullableColumn::new( result, Bitmap::new_constant(true, from.len()), diff --git a/src/query/functions/src/scalars/other.rs b/src/query/functions/src/scalars/other.rs index 0c4a44c38d557..ace497ed6a581 100644 --- a/src/query/functions/src/scalars/other.rs +++ b/src/query/functions/src/scalars/other.rs @@ -22,7 +22,6 @@ use databend_common_base::base::convert_number_size; use databend_common_base::base::uuid::Uuid; use databend_common_base::base::OrderedFloat; use databend_common_expression::error_to_null; -use databend_common_expression::types::binary::BinaryColumnBuilder; use databend_common_expression::types::boolean::BooleanDomain; use databend_common_expression::types::nullable::NullableColumn; use databend_common_expression::types::number::Float32Type; @@ -31,7 +30,7 @@ use databend_common_expression::types::number::Int64Type; use databend_common_expression::types::number::UInt32Type; use databend_common_expression::types::number::UInt8Type; use databend_common_expression::types::number::F64; -use databend_common_expression::types::string::StringColumn; +use databend_common_expression::types::string::StringColumnBuilder; use databend_common_expression::types::ArgType; use databend_common_expression::types::DataType; use databend_common_expression::types::DateType; @@ -229,16 +228,15 @@ pub fn register(registry: &mut FunctionRegistry) { "gen_random_uuid", |_| FunctionDomain::Full, |ctx| { - let mut builder = BinaryColumnBuilder::with_capacity(ctx.num_rows, 0); + let mut builder = StringColumnBuilder::with_capacity(ctx.num_rows); for _ in 0..ctx.num_rows { let value = Uuid::now_v7(); - write!(&mut builder.data, "{}", value).unwrap(); + write!(&mut builder.row_buffer, "{}", value).unwrap(); builder.commit_row(); } - let col = StringColumn::try_from(builder.build()).unwrap(); - Value::Column(col) + Value::Column(builder.build()) }, ); } diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 306223240b7b2..9433c6c4ee437 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -393,7 +393,9 @@ impl Binder { let original_context = bind_context.expr_context.clone(); 
bind_context.set_expr_context(ExprContext::GroupClaue); - match group_by { + + let group_by = Self::expand_group(group_by.clone())?; + match &group_by { GroupBy::Normal(exprs) => self.resolve_group_items( bind_context, select_list, @@ -416,25 +418,83 @@ impl Binder { GroupBy::GroupingSets(sets) => { self.resolve_grouping_sets(bind_context, select_list, sets, &available_aliases)?; } - // TODO: avoid too many clones. + _ => unreachable!(), + } + bind_context.set_expr_context(original_context); + Ok(()) + } + + pub fn expand_group(group_by: GroupBy) -> Result { + match group_by { + GroupBy::Normal(_) | GroupBy::All | GroupBy::GroupingSets(_) => Ok(group_by), + GroupBy::Cube(exprs) => { + // Expand CUBE to GroupingSets + let sets = Self::generate_cube_sets(exprs); + Ok(GroupBy::GroupingSets(sets)) + } GroupBy::Rollup(exprs) => { - // ROLLUP (a,b,c) => GROUPING SETS ((a,b,c), (a,b), (a), ()) - let mut sets = Vec::with_capacity(exprs.len() + 1); - for i in (0..=exprs.len()).rev() { - sets.push(exprs[0..i].to_vec()); + // Expand ROLLUP to GroupingSets + let sets = Self::generate_rollup_sets(exprs); + Ok(GroupBy::GroupingSets(sets)) + } + GroupBy::Combined(groups) => { + // Flatten and expand all nested GroupBy variants + let mut combined_sets = Vec::new(); + for group in groups { + match Self::expand_group(group)? { + GroupBy::Normal(exprs) => { + combined_sets = Self::cartesian_product(combined_sets, vec![exprs]); + } + GroupBy::GroupingSets(sets) => { + combined_sets = Self::cartesian_product(combined_sets, sets); + } + other => { + return Err(ErrorCode::SyntaxException(format!( + "COMBINED GROUP BY does not support {other:?}" + ))); + } + } } - self.resolve_grouping_sets(bind_context, select_list, &sets, &available_aliases)?; + Ok(GroupBy::GroupingSets(combined_sets)) } - GroupBy::Cube(exprs) => { - // CUBE (a,b) => GROUPING SETS ((a,b),(a),(b),()) // All subsets - let sets = (0..=exprs.len()) - .flat_map(|count| exprs.clone().into_iter().combinations(count)) - .collect::>(); - self.resolve_grouping_sets(bind_context, select_list, &sets, &available_aliases)?; + } + } + + /// Generate GroupingSets from CUBE (expr1, expr2, ...) + fn generate_cube_sets(exprs: Vec) -> Vec> { + (0..=exprs.len()) + .flat_map(|count| exprs.clone().into_iter().combinations(count)) + .collect::>() + } + + /// Generate GroupingSets from ROLLUP (expr1, expr2, ...) + fn generate_rollup_sets(exprs: Vec) -> Vec> { + let mut result = Vec::new(); + for i in (0..=exprs.len()).rev() { + result.push(exprs[..i].to_vec()); + } + result + } + + /// Perform Cartesian product of two sets of grouping sets + fn cartesian_product(set1: Vec>, set2: Vec>) -> Vec> { + if set1.is_empty() { + return set2; + } + + if set2.is_empty() { + return set1; + } + + let mut result = Vec::new(); + for s1 in set1 { + for s2 in &set2 { + let mut combined = s1.clone(); + combined.extend(s2.clone()); + result.push(combined); } } - bind_context.set_expr_context(original_context); - Ok(()) + result } pub fn bind_aggregate( @@ -525,6 +585,11 @@ impl Binder { set }) .collect::>(); + + // Because we are not using union all to implement grouping sets + // We will remove the duplicated grouping sets here. + // For example: SELECT brand, segment, SUM (quantity) FROM sales GROUP BY GROUPING sets(brand, segment), GROUPING sets(brand, segment); + // brand X segment will not appear twice in the result, the results are not standard but acceptable. 
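+        // Under a UNION ALL based implementation each duplicated set would simply emit its
+        // rows twice; here every set is evaluated inside one aggregate, so dropping duplicates
+        // only avoids repeated work and does not change which distinct groups are produced.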
let grouping_sets = grouping_sets.into_iter().unique().collect(); let mut dup_group_items = Vec::with_capacity(agg_info.group_items.len()); for (i, item) in agg_info.group_items.iter().enumerate() { diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_aggregate.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_aggregate.rs index 599122e2a97d7..06ad9643a7367 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_aggregate.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_aggregate.rs @@ -89,7 +89,11 @@ impl Rule for RulePushDownFilterAggregate { if predicate_used_columns.is_subset(&aggregate_child_prop.output_columns) && predicate_used_columns.is_subset(&aggregate_group_columns) { - pushed_down_predicates.push(predicate); + pushed_down_predicates.push(predicate.clone()); + // Keep full remaining_predicates cause grouping_sets exists + if aggregate.grouping_sets.is_some() { + remaining_predicates.push(predicate); + } } else { remaining_predicates.push(predicate) } @@ -98,6 +102,7 @@ impl Rule for RulePushDownFilterAggregate { let pushed_down_filter = Filter { predicates: pushed_down_predicates, }; + let mut result = if remaining_predicates.is_empty() { SExpr::create_unary( Arc::new(aggregate.into()), diff --git a/src/tests/sqlsmith/src/sql_gen/query.rs b/src/tests/sqlsmith/src/sql_gen/query.rs index 9ef1f0b68c664..c931a8cbd2e0c 100644 --- a/src/tests/sqlsmith/src/sql_gen/query.rs +++ b/src/tests/sqlsmith/src/sql_gen/query.rs @@ -236,6 +236,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { orders.push(order_by_expr); } } + _ => unimplemented!(), } } else { for _ in 0..order_nums { @@ -351,12 +352,10 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { groupby_items.push(groupby_item); } - match self.rng.gen_range(0..=4) { + match self.rng.gen_range(0..=2) { 0 => Some(GroupBy::Normal(groupby_items)), 1 => Some(GroupBy::All), 2 => Some(GroupBy::GroupingSets(vec![groupby_items])), - 3 => Some(GroupBy::Cube(groupby_items)), - 4 => Some(GroupBy::Rollup(groupby_items)), _ => unreachable!(), } } @@ -370,9 +369,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { }; match group_by { - Some(GroupBy::Normal(group_by)) - | Some(GroupBy::Cube(group_by)) - | Some(GroupBy::Rollup(group_by)) => { + Some(GroupBy::Normal(group_by)) => { let ty = self.gen_data_type(); let agg_expr = self.gen_agg_func(&ty); targets.push(SelectTarget::AliasedExpr { @@ -409,7 +406,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { alias: None, })); } - None => { + _ => { let select_num = self.rng.gen_range(1..=7); for _ in 0..select_num { let ty = self.gen_data_type(); diff --git a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test index 5d68e422f37d0..b2e84947356b4 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test +++ b/tests/sqllogictests/suites/base/03_common/03_0003_select_group_by.test @@ -142,7 +142,7 @@ statement ok CREATE TABLE IF NOT EXISTS t_variant(id Int null, var Variant null) Engine = Fuse statement ok -INSERT INTO t_variant VALUES(1, parse_json('{"k":"v"}')), (2, parse_json('{"k":"v"}')), (3, parse_json('"abcd"')), (4, parse_json('"abcd"')), (5, parse_json('12')), (6, parse_json('12')), (7, parse_json('[1,2,3]')), (8, parse_json('[1,2,3]')) +INSERT INTO t_variant VALUES(1, parse_json('{"k":"v"}')), (2, parse_json('{"k":"v"}')), (3, parse_json('"abcd"')), (4, parse_json('"abcd"')), (5, parse_json('12')), (6, 
parse_json('12')), (7, parse_json('[1,2,3]')), (8, parse_json('[1,2,3]')) query IIT SELECT max(id) as n, min(id), var FROM t_variant GROUP BY var ORDER BY n ASC @@ -159,7 +159,7 @@ statement ok CREATE TABLE IF NOT EXISTS t_array(id Int null, arr Array(Int32)) Engine = Fuse statement ok -INSERT INTO t_array VALUES(1, []), (2, []), (3, [1,2,3]), (4, [1,2,3]), (5, [4,5,6]), (6, [4,5,6]) +INSERT INTO t_array VALUES(1, []), (2, []), (3, [1,2,3]), (4, [1,2,3]), (5, [4,5,6]), (6, [4,5,6]) query I select count() from numbers(10) group by 'ab' @@ -300,3 +300,91 @@ select sum(number + 3 ), number % 3 from numbers(10) group by sum(number + 3 ) statement error (?s)1065.*GROUP BY items can't contain aggregate functions or window functions select sum(number + 3 ), number % 3 from numbers(10) group by 1, 2; + + +# test grouping sets, rollup + +statement ok +CREATE OR REPLACE TABLE sales ( + brand VARCHAR NOT NULL, + segment VARCHAR NOT NULL, + quantity INT NOT NULL +); + +statement ok +INSERT INTO sales (brand, segment, quantity) +VALUES + ('ABC', 'Premium', 100), + ('ABC', 'Basic', 200), + ('XYZ', 'Premium', 100), + ('XYZ', 'Basic', 300); + +query ITTI +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, GROUPING SETS(segment, quantity) order by 1,2,3; +---- +100 ABC NULL 100 +100 XYZ NULL 100 +200 ABC NULL 200 +300 XYZ NULL 300 +NULL ABC Basic 200 +NULL ABC Premium 100 +NULL XYZ Basic 300 +NULL XYZ Premium 100 + +query ITTI +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(segment, quantity) order by 1,2,3; +---- +100 ABC Premium 100 +100 XYZ Premium 100 +200 ABC Basic 200 +300 XYZ Basic 300 +NULL ABC Basic 200 +NULL ABC Premium 100 +NULL ABC NULL 300 +NULL XYZ Basic 300 +NULL XYZ Premium 100 +NULL XYZ NULL 400 + +query ITTI +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity) , segment order by 1,2,3; +---- +100 ABC Premium 100 +100 XYZ Premium 100 +200 ABC Basic 200 +300 XYZ Basic 300 +NULL ABC Basic 200 +NULL ABC Premium 100 +NULL XYZ Basic 300 +NULL XYZ Premium 100 + + + + +## results are deduplicated in grouping sets, the results are not standard but acceptable. +query ITTI +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), GROUPING sets(brand, segment, quantity) order by 1,2,3; +---- +100 ABC Premium 100 +100 ABC NULL 100 +100 XYZ Premium 100 +100 XYZ NULL 100 +200 ABC Basic 200 +200 ABC NULL 200 +300 XYZ Basic 300 +300 XYZ NULL 300 +NULL ABC Basic 200 +NULL ABC Premium 100 +NULL ABC NULL 300 +NULL XYZ Basic 300 +NULL XYZ Premium 100 +NULL XYZ NULL 400 + +## filter push down into grouping sets +query ITTI +select * from (SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY rollup( quantity, brand, segment) ) e where e.segment = 'Basic' order by quantity desc; +---- +300 XYZ Basic 300 +200 ABC Basic 200 + +statement error +SELECT quantity, brand, segment, SUM (quantity) FROM sales GROUP BY brand, rollup(quantity), all;
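
Note on the core of this change, with a small illustration. The binder no longer special-cases ROLLUP and CUBE inside analyze_group_items; expand_group first rewrites both (and any Combined list such as `GROUP BY a, ROLLUP (b, c)`) into a single GROUPING SETS, using a cartesian product to splice a plain GROUP BY item into every set that follows it. The sketch below is a minimal standalone rendering of that expansion strategy, not the Databend code itself: it uses `&str` in place of the AST `Expr` type, keeps the function names from the patch only for orientation, and assumes the itertools crate, which the planner already uses for `.unique()` and `.combinations()`.

// A minimal, self-contained sketch of the expansion performed by Binder::expand_group.
// `&str` stands in for the AST `Expr` type; only the shape of the algorithm is shown.
use itertools::Itertools;

type Set<'a> = Vec<&'a str>;

/// ROLLUP (a, b, c) => GROUPING SETS ((a, b, c), (a, b), (a), ())
fn generate_rollup_sets<'a>(exprs: &[&'a str]) -> Vec<Set<'a>> {
    (0..=exprs.len()).rev().map(|i| exprs[..i].to_vec()).collect()
}

/// CUBE (a, b) => GROUPING SETS ((a, b), (a), (b), ()) -- every subset
fn generate_cube_sets<'a>(exprs: &[&'a str]) -> Vec<Set<'a>> {
    (0..=exprs.len())
        .flat_map(|count| exprs.iter().copied().combinations(count))
        .collect()
}

/// Splice two lists of grouping sets together by concatenating every pair,
/// so `GROUP BY a, ROLLUP (b, c)` prefixes `a` onto each rollup set.
fn cartesian_product<'a>(set1: Vec<Set<'a>>, set2: Vec<Set<'a>>) -> Vec<Set<'a>> {
    if set1.is_empty() {
        return set2;
    }
    if set2.is_empty() {
        return set1;
    }
    let mut result = Vec::new();
    for s1 in &set1 {
        for s2 in &set2 {
            let mut combined = s1.clone();
            combined.extend_from_slice(s2);
            result.push(combined);
        }
    }
    result
}

fn main() {
    // GROUP BY a, ROLLUP (b, c)
    // => GROUPING SETS ((a, b, c), (a, b), (a))
    let expanded = cartesian_product(vec![vec!["a"]], generate_rollup_sets(&["b", "c"]));
    assert_eq!(expanded, vec![vec!["a", "b", "c"], vec!["a", "b"], vec!["a"]]);

    // CUBE (a, b) on its own => ((), (a), (b), (a, b)) in subset order.
    println!("{:?}", generate_cube_sets(&["a", "b"]));
}

The optimizer half of the fix follows from the same expansion: once a query uses grouping sets, some aggregate output rows carry NULL-extended group columns, so a predicate such as e.segment = 'Basic' that is pushed below the aggregate must also stay in remaining_predicates and be re-applied above it; otherwise the NULL-extended rollup rows would leak through the filter, which is exactly what the new "filter push down into grouping sets" case in 03_0003_select_group_by.test guards against.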