From 77203534e5e119030ca6b2ed8ff1b8e639abe34c Mon Sep 17 00:00:00 2001 From: baishen Date: Mon, 20 Jan 2025 12:56:49 +0800 Subject: [PATCH] feat(ddl): support alter table cast tuple type column (#17310) * feat(ddl): support alter table cast tuple type column * fix * fix tests * fix tests * rename `array_tuple` as `arrays_zip` * fix tests --- src/query/functions/src/scalars/array.rs | 127 ++++++++++++++ src/query/functions/tests/it/scalars/array.rs | 8 + .../tests/it/scalars/testdata/array.txt | 35 ++++ .../it/scalars/testdata/function_list.txt | 1 + .../interpreter_table_modify_column.rs | 166 ++++++++++++++++-- .../base/05_ddl/05_0003_ddl_alter_table.test | 32 +++- .../functions/02_0015_function_tuples.test | 5 + .../functions/02_0061_function_array.test | 22 +++ 8 files changed, 378 insertions(+), 18 deletions(-) diff --git a/src/query/functions/src/scalars/array.rs b/src/query/functions/src/scalars/array.rs index 8d602a1663a68..c868b0456d0a2 100644 --- a/src/query/functions/src/scalars/array.rs +++ b/src/query/functions/src/scalars/array.rs @@ -24,6 +24,7 @@ use databend_common_expression::types::number::SimpleDomain; use databend_common_expression::types::number::UInt64Type; use databend_common_expression::types::AnyType; use databend_common_expression::types::ArgType; +use databend_common_expression::types::ArrayColumn; use databend_common_expression::types::ArrayType; use databend_common_expression::types::BooleanType; use databend_common_expression::types::DataType; @@ -158,6 +159,132 @@ pub fn register(registry: &mut FunctionRegistry) { })) }); + // Returns a merged array of tuples in which the nth tuple contains all nth values of input arrays. + registry.register_function_factory("arrays_zip", |_, args_type| { + if args_type.is_empty() { + return None; + } + let args_type = args_type.to_vec(); + + let inner_types: Vec = args_type + .iter() + .map(|arg_type| { + let is_nullable = arg_type.is_nullable(); + match arg_type.remove_nullable() { + DataType::Array(box inner_type) => { + if is_nullable { + inner_type.wrap_nullable() + } else { + inner_type.clone() + } + } + _ => arg_type.clone(), + } + }) + .collect(); + let return_type = DataType::Array(Box::new(DataType::Tuple(inner_types.clone()))); + Some(Arc::new(Function { + signature: FunctionSignature { + name: "arrays_zip".to_string(), + args_type: args_type.clone(), + return_type, + }, + eval: FunctionEval::Scalar { + calc_domain: Box::new(|_, args_domain| { + let inner_domains = args_domain + .iter() + .map(|arg_domain| match arg_domain { + Domain::Nullable(nullable_domain) => match &nullable_domain.value { + Some(box Domain::Array(Some(inner_domain))) => { + Domain::Nullable(NullableDomain { + has_null: nullable_domain.has_null, + value: Some(Box::new(*inner_domain.clone())), + }) + } + _ => Domain::Nullable(nullable_domain.clone()), + }, + Domain::Array(Some(box inner_domain)) => inner_domain.clone(), + _ => arg_domain.clone(), + }) + .collect(); + FunctionDomain::Domain(Domain::Array(Some(Box::new(Domain::Tuple( + inner_domains, + ))))) + }), + eval: Box::new(move |args, ctx| { + let len = args.iter().find_map(|arg| match arg { + Value::Column(col) => Some(col.len()), + _ => None, + }); + + let mut offset = 0; + let mut offsets = Vec::new(); + offsets.push(0); + let tuple_type = DataType::Tuple(inner_types.clone()); + let mut builder = ColumnBuilder::with_capacity(&tuple_type, 0); + for i in 0..len.unwrap_or(1) { + let mut is_diff_len = false; + let mut array_len = None; + for arg in args { + let value = unsafe { arg.index_unchecked(i) }; + if let ScalarRef::Array(col) = value { + if let Some(array_len) = array_len { + if array_len != col.len() { + is_diff_len = true; + let err = format!( + "array length must be equal, but got {} and {}", + array_len, + col.len() + ); + ctx.set_error(builder.len(), err); + offsets.push(offset); + break; + } + } else { + array_len = Some(col.len()); + } + } + } + if is_diff_len { + continue; + } + let array_len = array_len.unwrap_or(1); + for j in 0..array_len { + let mut tuple_values = Vec::with_capacity(args.len()); + for arg in args { + let value = unsafe { arg.index_unchecked(i) }; + match value { + ScalarRef::Array(col) => { + let tuple_value = unsafe { col.index_unchecked(j) }; + tuple_values.push(tuple_value.to_owned()); + } + _ => { + tuple_values.push(value.to_owned()); + } + } + } + let tuple_value = Scalar::Tuple(tuple_values); + builder.push(tuple_value.as_ref()); + } + offset += array_len as u64; + offsets.push(offset); + } + + match len { + Some(_) => { + let array_column = ArrayColumn { + values: builder.build(), + offsets: offsets.into(), + }; + Value::Column(Column::Array(Box::new(array_column))) + } + _ => Value::Scalar(Scalar::Array(builder.build())), + } + }), + }, + })) + }); + registry.register_1_arg::, _, _>( "length", |_, _| FunctionDomain::Domain(SimpleDomain { min: 0, max: 0 }), diff --git a/src/query/functions/tests/it/scalars/array.rs b/src/query/functions/tests/it/scalars/array.rs index 216c6d24c46fb..28a81fc699845 100644 --- a/src/query/functions/tests/it/scalars/array.rs +++ b/src/query/functions/tests/it/scalars/array.rs @@ -52,6 +52,7 @@ fn test_array() { test_array_kurtosis(file); test_array_skewness(file); test_array_sort(file); + test_arrays_zip(file); } fn test_create(file: &mut impl Write) { @@ -731,3 +732,10 @@ fn test_array_sort(file: &mut impl Write) { &[], ); } + +fn test_arrays_zip(file: &mut impl Write) { + run_ast(file, "arrays_zip(NULL, NULL)", &[]); + run_ast(file, "arrays_zip(1, 2, 'a')", &[]); + run_ast(file, "arrays_zip([1,2,3], ['a','b','c'], 10)", &[]); + run_ast(file, "arrays_zip([1,2,3], ['a','b'], 10)", &[]); +} diff --git a/src/query/functions/tests/it/scalars/testdata/array.txt b/src/query/functions/tests/it/scalars/testdata/array.txt index b22c7b293f83c..64d9fcc8dff43 100644 --- a/src/query/functions/tests/it/scalars/testdata/array.txt +++ b/src/query/functions/tests/it/scalars/testdata/array.txt @@ -2382,3 +2382,38 @@ output domain : [{0.0..=5.6} ∪ {NULL}] output : [5.6, 3.4, 2.2, 1.2, NULL, NULL] +ast : arrays_zip(NULL, NULL) +raw expr : arrays_zip(NULL, NULL) +checked expr : arrays_zip(NULL, NULL) +optimized expr : [(NULL, NULL)] +output type : Array(Tuple(NULL, NULL)) +output domain : [({NULL}, {NULL})] +output : [(NULL, NULL)] + + +ast : arrays_zip(1, 2, 'a') +raw expr : arrays_zip(1, 2, 'a') +checked expr : arrays_zip(1_u8, 2_u8, "a") +optimized expr : [(1, 2, 'a')] +output type : Array(Tuple(UInt8, UInt8, String)) +output domain : [({1..=1}, {2..=2}, {"a"..="a"})] +output : [(1, 2, 'a')] + + +ast : arrays_zip([1,2,3], ['a','b','c'], 10) +raw expr : arrays_zip(array(1, 2, 3), array('a', 'b', 'c'), 10) +checked expr : arrays_zip(array(1_u8, 2_u8, 3_u8), array("a", "b", "c"), 10_u8) +optimized expr : [(1, 'a', 10), (2, 'b', 10), (3, 'c', 10)] +output type : Array(Tuple(UInt8, String, UInt8)) +output domain : [({1..=3}, {"a"..="c"}, {10..=10})] +output : [(1, 'a', 10), (2, 'b', 10), (3, 'c', 10)] + + +error: + --> SQL:1:1 + | +1 | arrays_zip([1,2,3], ['a','b'], 10) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ array length must be equal, but got 3 and 2 while evaluating function `arrays_zip([1, 2, 3], ['a', 'b'], 10)` in expr `arrays_zip(array(1, 2, 3), array('a', 'b'), 10)` + + + diff --git a/src/query/functions/tests/it/scalars/testdata/function_list.txt b/src/query/functions/tests/it/scalars/testdata/function_list.txt index a7ebd51716226..958c61ee3d99b 100644 --- a/src/query/functions/tests/it/scalars/testdata/function_list.txt +++ b/src/query/functions/tests/it/scalars/testdata/function_list.txt @@ -187,6 +187,7 @@ Functions overloads: 1 array_unique(Array(Nothing) NULL) :: UInt64 NULL 2 array_unique(Array(T0)) :: UInt64 3 array_unique(Array(T0) NULL) :: UInt64 NULL +0 arrays_zip FACTORY 0 as_array(Variant) :: Variant NULL 1 as_array(Variant NULL) :: Variant NULL 0 as_boolean(Variant) :: Boolean NULL diff --git a/src/query/service/src/interpreters/interpreter_table_modify_column.rs b/src/query/service/src/interpreters/interpreter_table_modify_column.rs index d749c81bcc954..8c7314fb80a1c 100644 --- a/src/query/service/src/interpreters/interpreter_table_modify_column.rs +++ b/src/query/service/src/interpreters/interpreter_table_modify_column.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashSet; use std::sync::Arc; use databend_common_catalog::catalog::Catalog; @@ -19,8 +20,10 @@ use databend_common_catalog::table::Table; use databend_common_catalog::table::TableExt; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::types::DataType; use databend_common_expression::ComputedExpr; use databend_common_expression::DataSchema; +use databend_common_expression::Scalar; use databend_common_expression::TableDataType; use databend_common_expression::TableField; use databend_common_expression::TableSchema; @@ -142,11 +145,19 @@ impl ModifyTableColumnInterpreter { let table_info = table.get_table_info(); let mut new_schema = schema.clone(); + let mut modify_field_indices = HashSet::new(); // first check default expr before lock table for (field, _comment) in field_and_comments { let column = &field.name.to_string(); let data_type = &field.data_type; - if let Some((i, _)) = schema.column_with_name(column) { + if let Some((i, old_field)) = schema.column_with_name(column) { + modify_field_indices.insert(i); + // if the field has different leaf column numbers, we need drop the old column + // and add a new one to generate new column id. otherwise, leaf column ids will conflict. + if field.data_type.num_leaf_columns() != old_field.data_type.num_leaf_columns() { + let _ = new_schema.drop_column(column); + let _ = new_schema.add_column(field, i); + } if let Some(default_expr) = &field.default_expr { let default_expr = default_expr.to_string(); new_schema.fields[i].data_type = data_type.clone(); @@ -336,29 +347,154 @@ impl ModifyTableColumnInterpreter { } // 1. construct sql for selecting data from old table - let mut sql = "select".to_string(); - schema + let query_fields = schema .fields() .iter() .enumerate() - .for_each(|(index, field)| { - if index != schema.fields().len() - 1 { - sql = format!("{} `{}`,", sql, field.name.clone()); + .map(|(index, field)| { + if modify_field_indices.contains(&index) { + let new_field = new_schema.field(index); + // If the column type is Tuple or Array(Tuple), the difference in the number of leaf columns may cause + // the auto cast to fail. + // We read the leaf column data, and then use build function to construct a new Tuple or Array(Tuple). + // Note: other nested types auto cast can still fail, we need a more general handling + // to solve this problem in the future. + match ( + field.data_type.remove_nullable(), + new_field.data_type.remove_nullable(), + ) { + ( + TableDataType::Tuple { + fields_name: old_fields_name, + .. + }, + TableDataType::Tuple { + fields_name: new_fields_name, + fields_type: new_fields_type, + }, + ) => { + let transform_funcs = new_fields_name + .iter() + .zip(new_fields_type.iter()) + .map(|(new_field_name, new_field_type)| { + match old_fields_name.iter().position(|n| n == new_field_name) { + Some(idx) => { + format!("`{}`.{}", field.name, idx + 1) + } + None => { + let new_data_type = DataType::from(new_field_type); + let default_value = + Scalar::default_value(&new_data_type); + format!("{default_value}") + } + } + }) + .collect::>() + .join(", "); + + format!( + "if(is_not_null(`{}`), tuple({}), NULL) AS {}", + field.name, transform_funcs, field.name + ) + } + ( + TableDataType::Array(box TableDataType::Tuple { + fields_name: old_fields_name, + .. + }), + TableDataType::Array(box TableDataType::Tuple { + fields_name: new_fields_name, + fields_type: new_fields_type, + }), + ) + | ( + TableDataType::Array(box TableDataType::Nullable( + box TableDataType::Tuple { + fields_name: old_fields_name, + .. + }, + )), + TableDataType::Array(box TableDataType::Tuple { + fields_name: new_fields_name, + fields_type: new_fields_type, + }), + ) + | ( + TableDataType::Array(box TableDataType::Tuple { + fields_name: old_fields_name, + .. + }), + TableDataType::Array(box TableDataType::Nullable( + box TableDataType::Tuple { + fields_name: new_fields_name, + fields_type: new_fields_type, + }, + )), + ) + | ( + TableDataType::Array(box TableDataType::Nullable( + box TableDataType::Tuple { + fields_name: old_fields_name, + .. + }, + )), + TableDataType::Array(box TableDataType::Nullable( + box TableDataType::Tuple { + fields_name: new_fields_name, + fields_type: new_fields_type, + }, + )), + ) => { + let transform_funcs = new_fields_name + .iter() + .zip(new_fields_type.iter()) + .map(|(new_field_name, new_field_type)| { + match old_fields_name.iter().position(|n| n == new_field_name) { + Some(idx) => { + format!( + "array_transform(`{}`, v -> v.{})", + field.name, + idx + 1 + ) + } + None => { + let new_data_type = DataType::from(new_field_type); + let default_value = + Scalar::default_value(&new_data_type); + format!("{default_value}") + } + } + }) + .collect::>() + .join(", "); + + format!( + "if(is_not_null(`{}`), arrays_zip({}), NULL) AS {}", + field.name, transform_funcs, field.name + ) + } + (_, _) => { + format!("`{}`", field.name) + } + } } else { - sql = format!( - "{} `{}` from `{}`.`{}`", - sql, - field.name.clone(), - self.plan.database, - self.plan.table - ); + format!("`{}`", field.name) } - }); + }) + .collect::>() + .join(", "); + + let sql = format!( + "SELECT {} FROM `{}`.`{}`", + query_fields, self.plan.database, self.plan.table + ); // 2. build plan by sql let mut planner = Planner::new(self.ctx.clone()); let (plan, _extras) = planner.plan_sql(&sql).await?; + let select_schema = plan.schema(); + // 3. build physical plan by plan let (select_plan, select_column_bindings) = match plan { Plan::Query { @@ -387,7 +523,7 @@ impl ModifyTableColumnInterpreter { plan_id: select_plan.get_id(), input: Box::new(select_plan), table_info: new_table.get_table_info().clone(), - select_schema: Arc::new(Arc::new(schema).into()), + select_schema, select_column_bindings, insert_schema: Arc::new(Arc::new(new_schema).into()), cast_needed: true, diff --git a/tests/sqllogictests/suites/base/05_ddl/05_0003_ddl_alter_table.test b/tests/sqllogictests/suites/base/05_ddl/05_0003_ddl_alter_table.test index 267532e8e5321..6126c94106d28 100644 --- a/tests/sqllogictests/suites/base/05_ddl/05_0003_ddl_alter_table.test +++ b/tests/sqllogictests/suites/base/05_ddl/05_0003_ddl_alter_table.test @@ -169,21 +169,47 @@ statement ok ALTER TABLE `05_0003_at_t4` MODIFY COLUMN c array(string) null statement ok -ALTER TABLE `05_0003_at_t4` MODIFY COLUMN d tuple(string, string) null +ALTER TABLE `05_0003_at_t4` MODIFY COLUMN d tuple(string, string, int) null query TT SHOW CREATE TABLE `05_0003_at_t4` ---- -05_0003_at_t4 CREATE TABLE "05_0003_at_t4" ( a VARCHAR NOT NULL, b VARCHAR NULL, c ARRAY(VARCHAR NULL) NULL, d TUPLE(1 VARCHAR NULL, 2 VARCHAR NULL) NULL ) ENGINE=FUSE +05_0003_at_t4 CREATE TABLE "05_0003_at_t4" ( a VARCHAR NOT NULL, b VARCHAR NULL, c ARRAY(VARCHAR NULL) NULL, d TUPLE(1 VARCHAR NULL, 2 VARCHAR NULL, 3 INT NULL) NULL ) ENGINE=FUSE query TTTT SELECT * FROM `05_0003_at_t4` ---- -a b ['c1','c2'] ('d1','d2') +a b ['c1','c2'] ('d1','d2',NULL) statement ok DROP TABLE IF EXISTS `05_0003_at_t4` +statement ok +CREATE OR REPLACE TABLE "05_0003_at_t5" ( a int, b array(tuple(int, int, string))) ENGINE=FUSE + +statement ok +INSERT INTO `05_0003_at_t5` VALUES(1, null),(2, [(1,2,'x'),(3,4,'y')]),(3, [(5,null,null),(6,null,'z')]); + +query IT +SELECT * FROM `05_0003_at_t5` +---- +1 NULL +2 [(1,2,'x'),(3,4,'y')] +3 [(5,NULL,NULL),(6,NULL,'z')] + +statement ok +ALTER TABLE `05_0003_at_t5` MODIFY COLUMN b array(tuple(int, float64, string, string)); + +query IT +SELECT * FROM `05_0003_at_t5` +---- +1 NULL +2 [(1,2.0,'x',NULL),(3,4.0,'y',NULL)] +3 [(5,NULL,NULL,NULL),(6,NULL,'z',NULL)] + +statement ok +DROP TABLE IF EXISTS `05_0003_at_t5` + statement ok drop table if exists t; diff --git a/tests/sqllogictests/suites/query/functions/02_0015_function_tuples.test b/tests/sqllogictests/suites/query/functions/02_0015_function_tuples.test index b3827b807d789..63460207ac009 100644 --- a/tests/sqllogictests/suites/query/functions/02_0015_function_tuples.test +++ b/tests/sqllogictests/suites/query/functions/02_0015_function_tuples.test @@ -1,3 +1,8 @@ statement ok SELECT (1, 'a', NULL, to_date(18869), (2.1, to_datetime(1630320462000000))) +query T +SELECT tuple([1,2,3], ['a','b','c'], 10); +---- +([1,2,3],['a','b','c'],10) + diff --git a/tests/sqllogictests/suites/query/functions/02_0061_function_array.test b/tests/sqllogictests/suites/query/functions/02_0061_function_array.test index 701fd718a8ea5..40cb07f5a50a2 100644 --- a/tests/sqllogictests/suites/query/functions/02_0061_function_array.test +++ b/tests/sqllogictests/suites/query/functions/02_0061_function_array.test @@ -398,8 +398,30 @@ select json_array_reduce([1,2,3,4]::Variant, (x, y) -> 3 + x + y), json_array_tr ---- 19 [] +query T +SELECT arrays_zip(1, 'a', null); +---- +[(1,'a',NULL)] + +query T +SELECT arrays_zip([1,2,3], ['a','b','c'], 10); +---- +[(1,'a',10),(2,'b',10),(3,'c',10)] + +statement error 1006 +SELECT arrays_zip([1,2,3], ['a','b'], 10); + +query T +SELECT arrays_zip(col1, col2) FROM t3; +---- +[(1,2),(2,2),(3,2)] +[(4,NULL),(5,NULL)] +[(NULL,4)] +[(7,5),(8,5)] + statement ok USE default statement ok DROP DATABASE array_func_test +