Skip to content

Commit

Permalink
feat(functions): add new function: map_pick (#15573)
Browse files Browse the repository at this point in the history
* feat(functions): add new function: map_pick

* feat(functions): add factory function for map_pick

* feat(functions): add more args_type check

* feat(functions): add args arrayType

* fix

* fix

* fix tests

* fix tests

---------

Co-authored-by: baishen <[email protected]>
  • Loading branch information
hanxuanliang and b41sh authored Nov 5, 2024
1 parent 31b7ceb commit 094f72d
Show file tree
Hide file tree
Showing 5 changed files with 375 additions and 56 deletions.
239 changes: 189 additions & 50 deletions src/query/functions/src/scalars/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ pub fn register(registry: &mut FunctionRegistry) {
vectorize_with_builder_2_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<1>>, MapType<GenericType<0>, GenericType<1>>>(
|keys, vals, output, ctx| {
let key_type = &ctx.generics[0];
if !key_type.is_boolean()
&& !key_type.is_string()
&& !key_type.is_numeric()
&& !key_type.is_decimal()
&& !key_type.is_date_or_date_time() {
if !check_valid_map_key_type(key_type) {
ctx.set_error(output.len(), format!("map keys can not be {}", key_type));
} else if keys.len() != vals.len() {
ctx.set_error(output.len(), format!(
Expand Down Expand Up @@ -241,43 +237,7 @@ pub fn register(registry: &mut FunctionRegistry) {
);

registry.register_function_factory("map_delete", |_, args_type| {
if args_type.len() < 2 {
return None;
}

let map_key_type = match args_type[0].remove_nullable() {
DataType::Map(box DataType::Tuple(type_tuple)) if type_tuple.len() == 2 => {
Some(type_tuple[0].clone())
}
DataType::EmptyMap => None,
_ => return None,
};

if let Some(map_key_type) = map_key_type {
for arg_type in args_type.iter().skip(1) {
if arg_type != &map_key_type {
return None;
}
}
} else {
let key_type = &args_type[1];
if !key_type.is_boolean()
&& !key_type.is_string()
&& !key_type.is_numeric()
&& !key_type.is_decimal()
&& !key_type.is_date_or_date_time()
{
return None;
}
for arg_type in args_type.iter().skip(2) {
if arg_type != key_type {
return None;
}
}
}

let return_type = args_type[0].clone();

let return_type = check_map_arg_types(args_type)?;
Some(Arc::new(Function {
signature: FunctionSignature {
name: "map_delete".to_string(),
Expand All @@ -297,31 +257,47 @@ pub fn register(registry: &mut FunctionRegistry) {
let mut output_map_builder =
ColumnBuilder::with_capacity(&return_type, input_length.unwrap_or(1));

let mut delete_key_list = HashSet::new();
for idx in 0..(input_length.unwrap_or(1)) {
let input_map_sref = match &args[0] {
let input_map = match &args[0] {
ValueRef::Scalar(map) => map.clone(),
ValueRef::Column(map) => unsafe { map.index_unchecked(idx) },
};

match &input_map_sref {
match &input_map {
ScalarRef::Null | ScalarRef::EmptyMap => {
output_map_builder.push_default();
}
ScalarRef::Map(col) => {
let mut delete_key_list = HashSet::new();

delete_key_list.clear();
for input_key_item in args.iter().skip(1) {
let input_key = match &input_key_item {
ValueRef::Scalar(scalar) => scalar.clone(),
ValueRef::Column(col) => unsafe {
col.index_unchecked(idx)
},
};

delete_key_list.insert(input_key.to_owned());
match input_key {
ScalarRef::EmptyArray | ScalarRef::Null => {}
ScalarRef::Array(arr_col) => {
for arr_key in arr_col.iter() {
if arr_key == ScalarRef::Null {
continue;
}
delete_key_list.insert(arr_key.to_owned());
}
}
_ => {
delete_key_list.insert(input_key.to_owned());
}
}
}
if delete_key_list.is_empty() {
output_map_builder.push(input_map);
continue;
}

let inner_builder_type = match input_map_sref.infer_data_type() {
let inner_builder_type = match input_map.infer_data_type() {
DataType::Map(box typ) => typ,
_ => unreachable!(),
};
Expand All @@ -330,7 +306,7 @@ pub fn register(registry: &mut FunctionRegistry) {
ColumnBuilder::with_capacity(&inner_builder_type, col.len());

let input_map: KvColumn<AnyType, AnyType> =
MapType::try_downcast_scalar(&input_map_sref).unwrap();
MapType::try_downcast_scalar(&input_map).unwrap();

input_map.iter().for_each(|(map_key, map_value)| {
if !delete_key_list.contains(&map_key.to_owned()) {
Expand Down Expand Up @@ -371,4 +347,167 @@ pub fn register(registry: &mut FunctionRegistry) {
.any(|(k, _)| k == key)
},
);

// `map_pick(map, key1, key2, ...)` / `map_pick(map, [key1, key2, ...])`:
// builds a map that keeps only the entries of the input map whose key is among
// the requested keys. Argument types are validated by `check_map_arg_types`;
// when that returns `None` the `?` below makes the factory decline the overload.
registry.register_function_factory("map_pick", |_, args_type: &[DataType]| {
let return_type = check_map_arg_types(args_type)?;
Some(Arc::new(Function {
signature: FunctionSignature {
name: "map_pick".to_string(),
args_type: args_type.to_vec(),
// The result has the same type as the map argument.
return_type: args_type[0].clone(),
},
eval: FunctionEval::Scalar {
// The output domain is reported as the domain of the map argument.
calc_domain: Box::new(|_, args_domain| {
FunctionDomain::Domain(args_domain[0].clone())
}),
eval: Box::new(move |args, _ctx| {
// Length of the first column argument, or `None` when every argument is a scalar.
let input_length = args.iter().find_map(|arg| match arg {
ValueRef::Column(col) => Some(col.len()),
_ => None,
});

// All-scalar input is evaluated as a single row.
let mut output_map_builder =
ColumnBuilder::with_capacity(&return_type, input_length.unwrap_or(1));

// Allocated once and cleared per row so the set is reused across rows.
let mut pick_key_list = HashSet::new();
for idx in 0..(input_length.unwrap_or(1)) {
// Map value for the current row.
// NOTE(review): `index_unchecked` assumes every column argument has at least
// `input_length` rows (i.e. all columns share one length) — confirm with the caller.
let input_map = match &args[0] {
ValueRef::Scalar(map) => map.clone(),
ValueRef::Column(map) => unsafe { map.index_unchecked(idx) },
};

match &input_map {
// NULL or empty input map: emit the builder's default value for this row.
ScalarRef::Null | ScalarRef::EmptyMap => {
output_map_builder.push_default();
}
ScalarRef::Map(col) => {
// Gather the keys requested for this row from every argument after the map.
pick_key_list.clear();
for input_key_item in args.iter().skip(1) {
let input_key = match &input_key_item {
ValueRef::Scalar(scalar) => scalar.clone(),
ValueRef::Column(col) => unsafe {
col.index_unchecked(idx)
},
};
match input_key {
// NULL keys and empty arrays contribute nothing.
ScalarRef::EmptyArray | ScalarRef::Null => {}
// Array argument: every non-NULL element is a requested key.
ScalarRef::Array(arr_col) => {
for arr_key in arr_col.iter() {
if arr_key == ScalarRef::Null {
continue;
}
pick_key_list.insert(arr_key.to_owned());
}
}
// Any other scalar is itself a requested key.
_ => {
pick_key_list.insert(input_key.to_owned());
}
}
}
// Nothing to pick: emit the default value and move on to the next row.
// NOTE(review): if `return_type` is nullable, the default is presumably NULL
// rather than an empty map — confirm this is the intended result.
if pick_key_list.is_empty() {
output_map_builder.push_default();
continue;
}

// Type of the map's inner key/value column, used to build the filtered entries.
let inner_builder_type = match input_map.infer_data_type() {
DataType::Map(box typ) => typ,
_ => unreachable!(),
};

let mut filtered_kv_builder =
ColumnBuilder::with_capacity(&inner_builder_type, col.len());

// Shadows the scalar with its key/value column view; the `unwrap` relies on the
// `ScalarRef::Map` match arm above.
let input_map: KvColumn<AnyType, AnyType> =
MapType::try_downcast_scalar(&input_map).unwrap();

// Keep only the entries whose key was requested, in the input map's order.
input_map.iter().for_each(|(map_key, map_value)| {
if pick_key_list.contains(&map_key.to_owned()) {
filtered_kv_builder.push(ScalarRef::Tuple(vec![
map_key.clone(),
map_value.clone(),
]));
}
});
output_map_builder
.push(ScalarRef::Map(filtered_kv_builder.build()));
}
// NOTE(review): relies on the first argument only ever being NULL, an empty map
// or a map — presumably guaranteed by `check_map_arg_types`; confirm.
_ => unreachable!(),
}
}

// Return a column if any argument was a column, otherwise a single scalar.
match input_length {
Some(_) => Value::Column(output_map_builder.build()),
None => Value::Scalar(output_map_builder.build_scalar()),
}
}),
},
}))
});
}

/// Validates the argument types of the key-filtering map functions and returns
/// the function's return type (the type of the first argument).
///
/// Rules:
/// 1. At least two arguments are required, and the first one must be a `Map`
///    or an `EmptyMap` (nullable wrappers are ignored).
/// 2. The second argument may be an `Array` or `EmptyArray` of keys; in that
///    case it must be the only key argument.
/// 3. Otherwise every remaining argument is a single key.
///
/// A key type is accepted when, after stripping nullability, it is `Null`, or
/// it equals the map's key type. For an `EmptyMap` there is no key type to
/// compare against, so any type allowed as a map key is accepted.
///
/// Returns `None` when the arguments do not satisfy these rules.
fn check_map_arg_types(args_type: &[DataType]) -> Option<DataType> {
    if args_type.len() < 2 {
        return None;
    }

    // `None` means the first argument is an `EmptyMap`, which carries no key type.
    let expected_key_type = match args_type[0].remove_nullable() {
        DataType::Map(box DataType::Tuple(fields)) if fields.len() == 2 => Some(fields[0].clone()),
        DataType::EmptyMap => None,
        _ => return None,
    };

    // Decides whether a (non-nullable) key type may be used with this map.
    let key_is_acceptable = |key_type: &DataType| -> bool {
        *key_type == DataType::Null
            || match &expected_key_type {
                Some(expected) => key_type == expected,
                None => check_valid_map_key_type(key_type),
            }
    };

    let only_one_key_arg = args_type.len() == 2;
    let keys_ok = match args_type[1].remove_nullable() {
        // An array of keys must be the sole key argument and hold acceptable elements.
        DataType::Array(box element_type) => {
            only_one_key_arg && key_is_acceptable(&element_type.remove_nullable())
        }
        // An empty array has no element type to check.
        DataType::EmptyArray => only_one_key_arg,
        // Plain keys: every argument after the map must be acceptable on its own.
        _ => args_type
            .iter()
            .skip(1)
            .all(|arg_type| key_is_acceptable(&arg_type.remove_nullable())),
    };

    if keys_ok {
        Some(args_type[0].clone())
    } else {
        None
    }
}

/// Returns `true` when `key_type` may be used as a map key, i.e. when it is a
/// boolean, string, numeric, decimal, or date / date-time type.
fn check_valid_map_key_type(key_type: &DataType) -> bool {
    // Grouped lazily so the predicates are still evaluated left to right and
    // short-circuit on the first match.
    let is_bool_or_string = || key_type.is_boolean() || key_type.is_string();
    let is_number = || key_type.is_numeric() || key_type.is_decimal();
    let is_temporal = || key_type.is_date_or_date_time();

    is_bool_or_string() || is_number() || is_temporal()
}
38 changes: 38 additions & 0 deletions src/query/functions/tests/it/scalars/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ fn test_map() {
test_map_cat(file);
test_map_delete(file);
test_map_contains_key(file);
test_map_pick(file);
}

fn test_map_cat(file: &mut impl Write) {
Expand Down Expand Up @@ -296,6 +297,11 @@ fn test_map_delete(file: &mut impl Write) {
"map_delete({'k1': 'v1', 'k2': 'v2', 'k3': 'v3', 'k4': 'v4'}, 'k3', 'k2')",
&[],
);
run_ast(
file,
"map_delete({'k1': 'v1', 'k2': 'v2', 'k3': 'v3', 'k4': 'v4'}, ['k3', 'k2'])",
&[],
);

// Deleting keys from a nested map
let columns = [
Expand Down Expand Up @@ -381,3 +387,35 @@ fn test_map_delete(file: &mut impl Write) {
&columns,
);
}

/// Writes the `map_pick` cases to the golden file: first constant-only
/// expressions, then one expression evaluated over string columns.
fn test_map_pick(file: &mut impl Write) {
    // Constant expressions, evaluated without input columns. Order matters
    // because the results are appended to `file` sequentially.
    let constant_exprs = [
        "map_pick({'a':1,'b':2,'c':3}, 'a', 'b')",
        "map_pick({'a':1,'b':2,'c':3}, ['a', 'b'])",
        "map_pick({'a':1,'b':2,'c':3}, [])",
        "map_pick({1:'a',2:'b',3:'c'}, 1, 3)",
        "map_pick({}, 'a', 'b')",
        "map_pick({}, [])",
    ];
    for expr in constant_exprs {
        run_ast(file, expr, &[]);
    }

    // `a_col`..`c_col` supply the map keys; `d_col`..`f_col` supply the values,
    // built with validity masks (some entries are marked invalid).
    let input_columns = [
        ("a_col", StringType::from_data(vec!["a", "b", "c"])),
        ("b_col", StringType::from_data(vec!["d", "e", "f"])),
        ("c_col", StringType::from_data(vec!["x", "y", "z"])),
        (
            "d_col",
            StringType::from_data_with_validity(vec!["v1", "v2", "v3"], vec![true, true, true]),
        ),
        (
            "e_col",
            StringType::from_data_with_validity(vec!["v4", "v5", ""], vec![true, true, false]),
        ),
        (
            "f_col",
            StringType::from_data_with_validity(vec!["v6", "", "v7"], vec![true, false, true]),
        ),
    ];
    let column_expr = "map_pick(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'a', 'b')";
    run_ast(file, column_expr, &input_columns);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2549,6 +2549,7 @@ Functions overloads:
0 map_keys(Map(Nothing)) :: Array(Nothing)
1 map_keys(Map(T0, T1)) :: Array(T0)
2 map_keys(Map(T0, T1) NULL) :: Array(T0) NULL
0 map_pick FACTORY
0 map_size(Map(Nothing)) :: UInt8
1 map_size(Map(T0, T1)) :: UInt64
2 map_size(Map(T0, T1) NULL) :: UInt64 NULL
Expand Down
Loading

0 comments on commit 094f72d

Please sign in to comment.