Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix get_type for higher-order array functions #13756

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ pub enum ArrayFunctionSignature {
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
/// or something that can be coerced to one of those types.
Array,
/// A function takes a single argument that must be a List/LargeList/FixedSizeList
/// which gets coerced to List, with element type recursively coerced to List too if it is list-like.
RecursiveArray,
/// Specialized Signature for MapArray
/// The function takes a single argument that must be a MapArray
MapArray,
Expand All @@ -198,6 +201,9 @@ impl std::fmt::Display for ArrayFunctionSignature {
ArrayFunctionSignature::Array => {
write!(f, "array")
}
ArrayFunctionSignature::RecursiveArray => {
write!(f, "recursive_array")
}
ArrayFunctionSignature::MapArray => {
write!(f, "map_array")
}
Expand Down
19 changes: 18 additions & 1 deletion datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err,
types::{LogicalType, NativeType},
utils::{coerced_fixed_size_list_to_list, list_ndims},
utils::list_ndims,
Result,
};
use datafusion_expr_common::{
Expand Down Expand Up @@ -414,7 +415,16 @@ fn get_valid_types(
_ => Ok(vec![vec![]]),
}
}

fn array(array_type: &DataType) -> Option<DataType> {
match array_type {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this says that if the type is a list, keep the type, but if the type is large list / fixed size list then take the field type?

Why doesn't it also take the field type for List 🤔 ? (Aka it doesn't make sense to me that List is treated differently than LargeList and FixedSizeList

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for backwards compat i should keep LargeList so it stays LargeList, will push shortly

Aka it doesn't make sense to me that List is treated differently than LargeList and FixedSizeList

not my invention, it was like this before.
i think the intention is "converge List, LL and FSL into one type... or maybe two types... to keep UDF impl simpler".

i am not attached to this approach, but i think code may be reliant on that

DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
// Note array functions can often change the number of elements
// so convert from FixedSize --> variable
DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),

_ => None,
}
}

fn recursive_array(array_type: &DataType) -> Option<DataType> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we extend the existing array function for nested array instead of creating another signature for nested array

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to do this, please advise!
But this function should go away with #13757.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this function should go away with #13757.

I don't understand -- if the goal is to remove recursive flattening, should we be adding new code to support it 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the pre-existing array signature implied recursively array-infication (replacing FixedLengthList with List, recursively), didn't imply flattening.

the recursive type normalization matters for flatten only, cause it (currently) operates recursively and otherwise would need to gain code to handle FixedLengthList inputs

the recursive array-ification was useless for other array functions and was made non-recursive.
to compensate for this change, new RecursiveArray signature was added for flatten case.

match array_type {
DataType::List(_)
| DataType::LargeList(_)
Expand Down Expand Up @@ -653,6 +663,13 @@ fn get_valid_types(
array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
ArrayFunctionSignature::RecursiveArray => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
}
recursive_array(&current_types[0])
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
}
ArrayFunctionSignature::MapArray => {
if current_types.len() != 1 {
return Ok(vec![vec![]]);
Expand Down
83 changes: 83 additions & 0 deletions datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,3 +993,86 @@ where
let data = mutable.freeze();
Ok(arrow::array::make_array(data))
}

#[cfg(test)]
mod tests {
use super::array_element_udf;
use arrow_schema::{DataType, Field};
use datafusion_common::{Column, DFSchema, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{cast, Expr, ExprSchemable};
use std::collections::HashMap;

// Regression test for https://github.com/apache/datafusion/issues/13755
#[test]
fn test_array_element_return_type_fixed_size_list() {
let fixed_size_list_type = DataType::FixedSizeList(
Field::new("some_arbitrary_test_field", DataType::Int32, false).into(),
13,
);
let array_type = DataType::List(
Field::new_list_field(fixed_size_list_type.clone(), true).into(),
);
let index_type = DataType::Int64;

let schema = DFSchema::from_unqualified_fields(
vec![
Field::new("my_array", array_type.clone(), false),
Field::new("my_index", index_type.clone(), false),
]
.into(),
HashMap::default(),
)
.unwrap();

let udf = array_element_udf();

// ScalarUDFImpl::return_type
assert_eq!(
udf.return_type(&[array_type.clone(), index_type.clone()])
.unwrap(),
fixed_size_list_type
);

// ScalarUDFImpl::return_type_from_exprs with typed exprs
assert_eq!(
udf.return_type_from_exprs(
&[
cast(Expr::Literal(ScalarValue::Null), array_type.clone()),
cast(Expr::Literal(ScalarValue::Null), index_type.clone()),
],
&schema,
&[array_type.clone(), index_type.clone()]
)
.unwrap(),
fixed_size_list_type
);

// ScalarUDFImpl::return_type_from_exprs with exprs not carrying type
assert_eq!(
udf.return_type_from_exprs(
&[
Expr::Column(Column::new_unqualified("my_array")),
Expr::Column(Column::new_unqualified("my_index")),
],
&schema,
&[array_type.clone(), index_type.clone()]
)
.unwrap(),
fixed_size_list_type
);

// Via ExprSchemable::get_type (e.g. SimplifyInfo)
let udf_expr = Expr::ScalarFunction(ScalarFunction {
func: array_element_udf(),
args: vec![
Expr::Column(Column::new_unqualified("my_array")),
Expr::Column(Column::new_unqualified("my_index")),
],
});
assert_eq!(
ExprSchemable::get_type(&udf_expr, &schema).unwrap(),
fixed_size_list_type
);
}
}
11 changes: 9 additions & 2 deletions datafusion/functions-nested/src/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ use datafusion_common::cast::{
use datafusion_common::{exec_err, Result};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
};
use std::any::Any;
use std::sync::{Arc, OnceLock};
Expand Down Expand Up @@ -56,7 +57,13 @@ impl Default for Flatten {
impl Flatten {
pub fn new() -> Self {
Self {
signature: Signature::array(Volatility::Immutable),
signature: Signature {
// TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
findepi marked this conversation as resolved.
Show resolved Hide resolved
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::RecursiveArray,
),
volatility: Volatility::Immutable,
},
aliases: vec![],
}
}
Expand Down
Loading