Fix get_type for higher-order array functions #13756

Merged (4 commits, Dec 18, 2024)
Changes from 2 commits
6 changes: 6 additions & 0 deletions datafusion/expr-common/src/signature.rs
@@ -175,6 +175,9 @@ pub enum ArrayFunctionSignature {
    /// The function takes a single argument that must be a List/LargeList/FixedSizeList
    /// or something that can be coerced to one of those types.
    Array,
    /// The function takes a single argument that must be a List/LargeList/FixedSizeList,
    /// which gets coerced to List, with the element type recursively coerced to List as well if it is list-like.
    RecursiveArray,
    /// Specialized Signature for MapArray
    /// The function takes a single argument that must be a MapArray
    MapArray,
@@ -198,6 +201,9 @@ impl std::fmt::Display for ArrayFunctionSignature {
            ArrayFunctionSignature::Array => {
                write!(f, "array")
            }
            ArrayFunctionSignature::RecursiveArray => {
                write!(f, "recursive_array")
            }
            ArrayFunctionSignature::MapArray => {
                write!(f, "map_array")
            }
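To make the difference between the two variants concrete, here is a small, purely illustrative snippet (not part of this diff; field names and nullability are arbitrary) showing the types that `Array` and `RecursiveArray` would coerce a nested fixed-size list argument to, per the doc comments above and the array()/recursive_array() helpers in the next file:

    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field};

    fn main() {
        let int_item = Arc::new(Field::new("item", DataType::Int32, true));
        let inner_fsl = Arc::new(Field::new(
            "item",
            DataType::FixedSizeList(Arc::clone(&int_item), 2),
            true,
        ));

        // Input argument type: FixedSizeList<FixedSizeList<Int32, 2>, 3>
        let input = DataType::FixedSizeList(Arc::clone(&inner_fsl), 3);

        // `Array` rewrites only the outermost layer: List<FixedSizeList<Int32, 2>>
        let array_coerced = DataType::List(Arc::clone(&inner_fsl));

        // `RecursiveArray` rewrites every list-like layer: List<List<Int32>>
        let recursive_coerced = DataType::List(Arc::new(Field::new(
            "item",
            DataType::List(Arc::clone(&int_item)),
            true,
        )));

        println!("{input:?} -> {array_coerced:?} (Array) or {recursive_coerced:?} (RecursiveArray)");
    }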
21 changes: 20 additions & 1 deletion datafusion/expr/src/type_coercion/functions.rs
@@ -21,10 +21,11 @@ use arrow::{
    compute::can_cast_types,
    datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::{
    exec_err, internal_datafusion_err, internal_err, plan_err,
    types::{LogicalType, NativeType},
    utils::{coerced_fixed_size_list_to_list, list_ndims},
    utils::list_ndims,
    Result,
};
use datafusion_expr_common::{
@@ -414,7 +415,18 @@ fn get_valid_types(
            _ => Ok(vec![vec![]]),
        }
    }

    fn array(array_type: &DataType) -> Option<DataType> {
        match array_type {
Contributor:

so this says that if the type is a list, keep the type, but if the type is large list / fixed size list then take the field type?

Why doesn't it also take the field type for List 🤔? (Aka it doesn't make sense to me that List is treated differently than LargeList and FixedSizeList.)

Member Author:

for backwards compat i should keep LargeList so it stays LargeList, will push shortly

> Aka it doesn't make sense to me that List is treated differently than LargeList and FixedSizeList

not my invention, it was like this before.
i think the intention is "converge List, LargeList and FixedSizeList into one type... or maybe two types... to keep UDF impls simpler".

i am not attached to this approach, but i think code may be reliant on it

            DataType::List(_) => Some(array_type.clone()),
            DataType::LargeList(field) | DataType::FixedSizeList(field, _) => {
                Some(DataType::List(Arc::clone(field)))
            }
            _ => None,
        }
    }

    fn recursive_array(array_type: &DataType) -> Option<DataType> {
Contributor:

Can we extend the existing array function for nested arrays instead of creating another signature for nested arrays?

Member Author:

I don't know how to do this, please advise!
But this function should go away with #13757.

Contributor:

> But this function should go away with #13757.

I don't understand -- if the goal is to remove recursive flattening, should we be adding new code to support it? 🤔

Member Author:

the pre-existing array signature implied recursive array-ification (replacing FixedSizeList with List, recursively); it didn't imply flattening.

the recursive type normalization matters for flatten only, because it (currently) operates recursively and would otherwise need to gain code to handle FixedSizeList inputs.

the recursive array-ification was useless for the other array functions and was made non-recursive.
to compensate for this change, a new RecursiveArray signature was added for the flatten case (see the sketch after this diff).

        match array_type {
            DataType::List(_)
            | DataType::LargeList(_)
@@ -653,6 +665,13 @@ fn get_valid_types(
                array(&current_types[0])
                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
            }
            ArrayFunctionSignature::RecursiveArray => {
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }
                recursive_array(&current_types[0])
                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
            }
            ArrayFunctionSignature::MapArray => {
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
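The body of recursive_array is collapsed in this view. As a hedged sketch only (an assumption based on the coerced_fixed_size_list_to_list import added above, not the hidden lines themselves, and assuming that helper returns the recursively coerced DataType), it plausibly looks like this:

    use arrow::datatypes::DataType;
    use datafusion_common::utils::coerced_fixed_size_list_to_list;

    // Accept any list-like type and coerce it to `List`, rewriting list-like
    // element types to `List` at every nesting level; reject other inputs.
    fn recursive_array(array_type: &DataType) -> Option<DataType> {
        match array_type {
            DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => {
                Some(coerced_fixed_size_list_to_list(array_type))
            }
            _ => None,
        }
    }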
81 changes: 81 additions & 0 deletions datafusion/functions-nested/src/extract.rs
@@ -993,3 +993,84 @@ where
    let data = mutable.freeze();
    Ok(arrow::array::make_array(data))
}

#[cfg(test)]
mod tests {
    use super::array_element_udf;
    use arrow_schema::{DataType, Field};
    use datafusion_common::{Column, DFSchema, ScalarValue};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{cast, Expr, ExprSchemable};
    use std::collections::HashMap;

    #[test]
    fn test_array_element_return_type() {
Contributor:

I think we can add tests in an slt file that cover the array signature test cases, so we can avoid creating a rust test here.

Member Author:

The rust test allows explicitly exercising various ways of getting an expression's type.
Before i wrote it, I wasn't even sure whether it's a bug or a feature.

I can add an slt test, what would it look like?

Member Author:

I did try to write some slt regression tests, but i couldn't expose the bug. Yet, the unit test proves the bug exists.
I trust you have a better intuition for how a signature-related bug can be exposed in SLT. Please advise.

        let complex_type = DataType::FixedSizeList(
Contributor:

When I change this complex type to DataType::List the test passes 🤔

        let complex_type = DataType::List(
            Field::new("some_arbitrary_test_field", DataType::Int32, false).into(),
        );

It also passes when complex_type is a Struct

        let complex_type = DataType::Struct(Fields::from(vec![
            Arc::new(Field::new("some_arbitrary_test_field", DataType::Int32, false)),
        ]));

It seems to me like there is something about FixedSizeList that is causing issues

Contributor:

Weird, when I remove this line in expr_schema the test passes (with FixedSizeList):

diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 3317deafb..50aeb222f 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -152,6 +152,7 @@ impl ExprSchemable for Expr {
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;

+
                 // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
                 let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
                     .map_err(|err| {
@@ -168,7 +169,7 @@ impl ExprSchemable for Expr {

                 // Perform additional function arguments validation (due to limited
                 // expressiveness of `TypeSignature`), then infer return type
-                Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
+                Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
             }
             Expr::WindowFunction(window_function) => self
                 .data_type_and_nullable_with_window_function(schema, window_function)

Which basically says: pass the input data types directly to the function call rather than calling data_types_with_scalar_udf first (which claims to do type coercion)

Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)

🤔 this looks like it was added in Sep via 1b3608d (before that the input types were passed directly) 🤔

Contributor:

It doesn't seem right to me that ExprSchema is coercing the arguments (implicitly) 🤔

Member Author:

> It seems to me like there is something about FixedSizeList that is causing issues

correct, #13756 (comment)

Member Author:

> Weird, when I remove this line in expr_schema the test passes (with FixedSizeList):

i did the same, basically removing this block

// Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
    .map_err(|err| {
        plan_datafusion_err!(
            "{} {}",
            err,
            utils::generate_signature_error_msg(
                func.name(),
                func.signature().clone(),
                &arg_data_types,
            )
        )
    })?;

it's enough to fix the unit test in this PR
but other things start to fail

> It doesn't seem right to me that ExprSchema is coercing the arguments (implicitly) 🤔

agreed

Contributor (@jayzhan211, Dec 17, 2024):

> the function arguments should already be of the right coerced type
>
> I don't know the context of why we needed to apply coercion rules in the first place

The reason is that we can't guarantee the input is already coerced.

To determine the return type of a function for a given set of inputs, we follow these steps:

  1. Input Validation: Check if the number of inputs is correct and whether their types match the expected types.
  2. Type Coercion: If the input types don't match exactly, attempt to coerce them into compatible types.
  3. Return Type Decision: Once coercion is complete (if applicable), decide the return type based on the resulting input types.

That is why we have coercion in get_type for the return type. We could move the coercion out of get_type into ScalarFunction::new_udf (a rough sketch of these steps follows below).
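A rough, illustrative sketch of those three steps, using names that appear in this PR's diffs (data_types_with_scalar_udf, ScalarUDF::return_type); the import path and exact wiring are assumptions, not the verbatim DataFusion code path:

    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
    use datafusion_expr::ScalarUDF;

    fn infer_return_type(func: &ScalarUDF, arg_types: &[DataType]) -> Result<DataType> {
        // Steps 1 + 2: check the argument count and coerce the argument types
        // according to the function's signature.
        let coerced_types = data_types_with_scalar_udf(arg_types, func)?;
        // Step 3: decide the return type from the coerced input types.
        func.return_type(&coerced_types)
    }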

Member Author:

> How about we compute the return_type when the function is created, and get_type read the value.

I like the idea in principle.

It should be combined with a new ScalarUDFImpl sub-trait that doesn't have return type-related methods at all, since they are not to be used once the plan is constructed.

> The reason is that we can't guarantee the input is already coerced.

in a logical plan we can.

My understanding is that the coercing analyzer also calls the get_type functions.
That can be solved by changing how the coercing analyzer tracks its internal state.

But the real problem is that the same types, LogicalPlan & Expr, have two meanings: syntactic and semantic. So in the code we go back and forth about what can and cannot be guaranteed for an Expr or LogicalPlan instance.
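Purely as an illustration of that sub-trait idea (hypothetical names, not an API from this PR or from DataFusion): a "resolved" UDF view could drop the type-resolution methods entirely, because the return type was fixed when the plan was built.

    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::ColumnarValue;

    // Hypothetical sketch: everything needed to execute an already-planned call,
    // with no return_type()/signature() left to ask after analysis time.
    trait ResolvedScalarUDF {
        fn name(&self) -> &str;
        fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue>;
    }

    // The plan node would then carry the pre-computed type next to the function.
    struct ResolvedScalarFunction {
        func: Box<dyn ResolvedScalarUDF>,
        return_type: DataType,
    }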

Contributor:

> the same types, LogicalPlan & Expr, have two meanings: syntactic and semantic.

Is there an example of the difference between these two, especially for functions? For Expr::ScalarFunction there is no difference in the LogicalPlan (we don't do anything special), but I think that is what you don't expect. What should we have in the LogicalPlan, Expr::ScalarFunction but with coerced input?

Contributor:

> since they are not to be used once the plan is constructed.

Why is get_type not supposed to be available after the plan is constructed from the Expr?

Member Author:

> Is there an example of the difference between these two, especially for functions?

the difference is more apparent for duplicate syntax (such as IS NULL vs IS UNKNOWN) and syntax sugar (order by 1, order by all, select *).
for a function call the difference is about the function being resolved (typed and inputs coerced) or not.

> since they are not to be used once the plan is constructed.
>
> Why is get_type not supposed to be available after the plan is constructed from the Expr?

for a fully resolved logical plan it's a fair question to ask what the type of an expression is (and the answer may or may not be available in O(1))

however, there is no point in asking a UDF what its type is, since we already asked it

think of this as the engine and the UDF being implemented by independent parties, with the UDF interface being a contract layer.
you cross the contract layer when you have to (at analysis time), but crossing it multiple times with the same question should be avoided.

Field::new("some_arbitrary_test_field", DataType::Int32, false).into(),
13,
);
let array_type =
DataType::List(Field::new_list_field(complex_type.clone(), true).into());
let index_type = DataType::Int64;

let schema = DFSchema::from_unqualified_fields(
vec![
Field::new("my_array", array_type.clone(), false),
Field::new("my_index", index_type.clone(), false),
]
.into(),
HashMap::default(),
)
.unwrap();

let udf = array_element_udf();

// ScalarUDFImpl::return_type
assert_eq!(
udf.return_type(&[array_type.clone(), index_type.clone()])
.unwrap(),
complex_type
);

// ScalarUDFImpl::return_type_from_exprs with typed exprs
assert_eq!(
udf.return_type_from_exprs(
&[
cast(Expr::Literal(ScalarValue::Null), array_type.clone()),
cast(Expr::Literal(ScalarValue::Null), index_type.clone()),
],
&schema,
&[array_type.clone(), index_type.clone()]
)
.unwrap(),
complex_type
);

// ScalarUDFImpl::return_type_from_exprs with exprs not carrying type
assert_eq!(
udf.return_type_from_exprs(
&[
Expr::Column(Column::new_unqualified("my_array")),
Expr::Column(Column::new_unqualified("my_index")),
],
&schema,
&[array_type.clone(), index_type.clone()]
)
.unwrap(),
complex_type
);

// Via ExprSchemable::get_type (e.g. SimplifyInfo)
let udf_expr = Expr::ScalarFunction(ScalarFunction {
func: array_element_udf(),
args: vec![
Expr::Column(Column::new_unqualified("my_array")),
Expr::Column(Column::new_unqualified("my_index")),
],
});
assert_eq!(
ExprSchemable::get_type(&udf_expr, &schema).unwrap(),
complex_type
);
Member Author:

This didn't pass before the change. The assertions above did pass.

    }
}
11 changes: 9 additions & 2 deletions datafusion/functions-nested/src/flatten.rs
@@ -28,7 +28,8 @@ use datafusion_common::cast::{
use datafusion_common::{exec_err, Result};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
use datafusion_expr::{
    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
    ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
    TypeSignature, Volatility,
};
use std::any::Any;
use std::sync::{Arc, OnceLock};
@@ -56,7 +57,13 @@ impl Default for Flatten {
impl Flatten {
    pub fn new() -> Self {
        Self {
            signature: Signature::array(Volatility::Immutable),
            signature: Signature {
                // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
                type_signature: TypeSignature::ArraySignature(
                    ArrayFunctionSignature::RecursiveArray,
                ),
                volatility: Volatility::Immutable,
            },
            aliases: vec![],
        }
    }