From d321ba3cc31ba20823a7b6452899da070f528522 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 22 Mar 2024 14:00:11 -0400 Subject: [PATCH] Move trim functions (btrim, ltrim, rtrim) to datafusion_functions, make expr_fn API consistent (#9730) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Move trim functions to datafusion-functions * Doc updates for ltrim, rtrim and trim to reflect how they actually function. * Fixed struct name Trim -> BTrim --- .../tests/dataframe/dataframe_functions.rs | 4 +- datafusion/expr/src/built_in_function.rs | 27 -- datafusion/expr/src/expr_fn.rs | 21 -- datafusion/functions/Cargo.toml | 7 +- .../src/string/{trim.rs => btrim.rs} | 37 ++- datafusion/functions/src/string/common.rs | 265 +++++++++++++++ datafusion/functions/src/string/ltrim.rs | 77 +++++ datafusion/functions/src/string/mod.rs | 301 +++--------------- datafusion/functions/src/string/rtrim.rs | 77 +++++ .../functions/src/string/starts_with.rs | 3 +- datafusion/functions/src/string/to_hex.rs | 3 +- datafusion/functions/src/string/upper.rs | 5 +- datafusion/physical-expr/src/functions.rs | 187 ----------- .../physical-expr/src/string_expressions.rs | 94 +----- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 28 +- datafusion/proto/src/logical_plan/to_proto.rs | 3 - datafusion/sql/src/expr/mod.rs | 53 +-- .../source/user-guide/sql/scalar_functions.md | 54 ++-- 21 files changed, 559 insertions(+), 714 deletions(-) rename datafusion/functions/src/string/{trim.rs => btrim.rs} (73%) create mode 100644 datafusion/functions/src/string/common.rs create mode 100644 datafusion/functions/src/string/ltrim.rs create mode 100644 datafusion/functions/src/string/rtrim.rs diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index cea701492910..6ebd64c9b628 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -367,7 +367,7 @@ async fn test_fn_lpad_with_string() -> Result<()> { #[tokio::test] async fn test_fn_ltrim() -> Result<()> { - let expr = ltrim(lit(" a b c ")); + let expr = ltrim(vec![lit(" a b c ")]); let expected = [ "+-----------------------------------------+", @@ -384,7 +384,7 @@ async fn test_fn_ltrim() -> Result<()> { #[tokio::test] async fn test_fn_ltrim_with_columns() -> Result<()> { - let expr = ltrim(col("a")); + let expr = ltrim(vec![col("a")]); let expected = [ "+---------------+", diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fffe2cf4c9c9..785965f6f693 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -107,8 +107,6 @@ pub enum BuiltinScalarFunction { Ascii, /// bit_length BitLength, - /// btrim - Btrim, /// character_length CharacterLength, /// chr @@ -127,8 +125,6 @@ pub enum BuiltinScalarFunction { Lpad, /// lower Lower, - /// ltrim - Ltrim, /// octet_length OctetLength, /// random @@ -143,8 +139,6 @@ pub enum BuiltinScalarFunction { Right, /// rpad Rpad, - /// rtrim - Rtrim, /// split_part SplitPart, /// strpos @@ -248,7 +242,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => Volatility::Immutable, - BuiltinScalarFunction::Btrim => Volatility::Immutable, BuiltinScalarFunction::CharacterLength => Volatility::Immutable, BuiltinScalarFunction::Chr => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, @@ -258,7 +251,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Left => Volatility::Immutable, BuiltinScalarFunction::Lpad => Volatility::Immutable, BuiltinScalarFunction::Lower => Volatility::Immutable, - BuiltinScalarFunction::Ltrim => Volatility::Immutable, BuiltinScalarFunction::OctetLength => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, BuiltinScalarFunction::Repeat => Volatility::Immutable, @@ -266,7 +258,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Reverse => Volatility::Immutable, BuiltinScalarFunction::Right => Volatility::Immutable, BuiltinScalarFunction::Rpad => Volatility::Immutable, - BuiltinScalarFunction::Rtrim => Volatility::Immutable, BuiltinScalarFunction::SplitPart => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, @@ -303,9 +294,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::BitLength => { utf8_to_int_type(&input_expr_types[0], "bit_length") } - BuiltinScalarFunction::Btrim => { - utf8_to_str_type(&input_expr_types[0], "btrim") - } BuiltinScalarFunction::CharacterLength => { utf8_to_int_type(&input_expr_types[0], "character_length") } @@ -325,9 +313,6 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "lower") } BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), - BuiltinScalarFunction::Ltrim => { - utf8_to_str_type(&input_expr_types[0], "ltrim") - } BuiltinScalarFunction::OctetLength => { utf8_to_int_type(&input_expr_types[0], "octet_length") } @@ -347,9 +332,6 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "right") } BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), - BuiltinScalarFunction::Rtrim => { - utf8_to_str_type(&input_expr_types[0], "rtrim") - } BuiltinScalarFunction::SplitPart => { utf8_to_str_type(&input_expr_types[0], "split_part") } @@ -456,12 +438,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::Btrim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim => Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], - self.volatility(), - ), BuiltinScalarFunction::Chr => { Signature::uniform(1, vec![Int64], self.volatility()) } @@ -703,7 +679,6 @@ impl BuiltinScalarFunction { // string functions BuiltinScalarFunction::Ascii => &["ascii"], BuiltinScalarFunction::BitLength => &["bit_length"], - BuiltinScalarFunction::Btrim => &["btrim"], BuiltinScalarFunction::CharacterLength => { &["character_length", "char_length", "length"] } @@ -715,14 +690,12 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Left => &["left"], BuiltinScalarFunction::Lower => &["lower"], BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Ltrim => &["ltrim"], BuiltinScalarFunction::OctetLength => &["octet_length"], BuiltinScalarFunction::Repeat => &["repeat"], BuiltinScalarFunction::Replace => &["replace"], BuiltinScalarFunction::Reverse => &["reverse"], BuiltinScalarFunction::Right => &["right"], BuiltinScalarFunction::Rpad => &["rpad"], - BuiltinScalarFunction::Rtrim => &["rtrim"], BuiltinScalarFunction::SplitPart => &["split_part"], BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8667f631c507..a834ccab9d15 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -601,12 +601,6 @@ scalar_expr!( scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Lower, lower, string, "convert the string to lower case"); -scalar_expr!( - Ltrim, - ltrim, - string, - "removes all characters, spaces by default, from the beginning of a string" -); scalar_expr!( OctetLength, octet_length, @@ -617,12 +611,6 @@ scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `fro scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times"); scalar_expr!(Reverse, reverse, string, "reverses the `string`"); scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`"); -scalar_expr!( - Rtrim, - rtrim, - string, - "removes all characters, spaces by default, from the end of a string" -); scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); @@ -640,11 +628,6 @@ nary_scalar_expr!( rpad, "fill up a string to the length by appending the characters" ); -nary_scalar_expr!( - Btrim, - btrim, - "removes all characters, spaces by default, from both sides of a string" -); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c nary_scalar_expr!( @@ -1082,8 +1065,6 @@ mod test { test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); - test_nary_scalar_expr!(Btrim, btrim, string); - test_nary_scalar_expr!(Btrim, btrim, string, characters); test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Chr, chr, string); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); @@ -1093,7 +1074,6 @@ mod test { test_scalar_expr!(Lower, lower, string); test_nary_scalar_expr!(Lpad, lpad, string, count); test_nary_scalar_expr!(Lpad, lpad, string, count, characters); - test_scalar_expr!(Ltrim, ltrim, string); test_scalar_expr!(OctetLength, octet_length, string); test_scalar_expr!(Replace, replace, string, from, to); test_scalar_expr!(Repeat, repeat, string, count); @@ -1101,7 +1081,6 @@ mod test { test_scalar_expr!(Right, right, string, count); test_nary_scalar_expr!(Rpad, rpad, string, count); test_nary_scalar_expr!(Rpad, rpad, string, count, characters); - test_scalar_expr!(Rtrim, rtrim, string); test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index b12c99e84a90..0410d89d123f 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -29,10 +29,9 @@ authors = { workspace = true } rust-version = { workspace = true } [features] -# enable string functions -string_expressions = [] # enable core functions core_expressions = [] +crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] # enable datetime functions datetime_expressions = [] # Enable encoding by default so the doctests work. In general don't automatically enable all packages. @@ -51,7 +50,9 @@ encoding_expressions = ["base64", "hex"] math_expressions = [] # enable regular expressions regex_expressions = ["regex"] -crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] +# enable string functions +string_expressions = [] + [lib] name = "datafusion_functions" path = "src/lib.rs" diff --git a/datafusion/functions/src/string/trim.rs b/datafusion/functions/src/string/btrim.rs similarity index 73% rename from datafusion/functions/src/string/trim.rs rename to datafusion/functions/src/string/btrim.rs index e04a171722e3..de1c9cc69b72 100644 --- a/datafusion/functions/src/string/trim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -16,30 +16,30 @@ // under the License. use arrow::array::{ArrayRef, OffsetSizeTrait}; -use arrow::datatypes::DataType; -use datafusion_common::exec_err; -use datafusion_common::Result; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use crate::string::{make_scalar_function, utf8_to_str_type}; +use arrow::datatypes::DataType; -use super::{general_trim, TrimType}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; -/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +use crate::string::common::*; + +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { +fn btrim(args: &[ArrayRef]) -> Result { general_trim::(args, TrimType::Both) } #[derive(Debug)] -pub(super) struct TrimFunc { +pub(super) struct BTrimFunc { signature: Signature, + aliases: Vec, } -impl TrimFunc { +impl BTrimFunc { pub fn new() -> Self { use DataType::*; Self { @@ -47,17 +47,18 @@ impl TrimFunc { vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], Volatility::Immutable, ), + aliases: vec![String::from("trim")], } } } -impl ScalarUDFImpl for TrimFunc { +impl ScalarUDFImpl for BTrimFunc { fn as_any(&self) -> &dyn Any { self } fn name(&self) -> &str { - "trim" + "btrim" } fn signature(&self) -> &Signature { @@ -65,14 +66,18 @@ impl ScalarUDFImpl for TrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "trim") + utf8_to_str_type(&arg_types[0], "btrim") } fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { DataType::Utf8 => make_scalar_function(btrim::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(btrim::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function trim"), + other => exec_err!("Unsupported data type {other:?} for function btrim"), } } + + fn aliases(&self) -> &[String] { + &self.aliases + } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs new file mode 100644 index 000000000000..97465420fb99 --- /dev/null +++ b/datafusion/functions/src/string/common.rs @@ -0,0 +1,265 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::Result; +use datafusion_common::{exec_err, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_physical_expr::functions::Hint; + +pub(crate) enum TrimType { + Left, + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +pub(crate) fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + + match args.len() { + 1 => { + let result = string_array + .iter() + .map(|string| string.map(|string: &str| func(string, " "))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let characters_array = as_generic_string_array::(&args[1])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .map(|(string, characters)| match (string, characters) { + (Some(string), Some(characters)) => Some(func(string, characters)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } + } +} + +/// Creates a function to identify the optimal return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! get_optimal_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + DataType::Dictionary(_, value_type) => match **value_type { + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + _ => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + **value_type + ); + } + }, + data_type => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + data_type + ); + } + }) + } + }; +} + +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. +get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +/// applies a unary expression to `args[0]` that is expected to be downcastable to +/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) +/// # Errors +/// This function errors when: +/// * the number of arguments is not 1 +/// * the first argument is not castable to a `GenericStringArray` +pub(crate) fn unary_string_function<'a, T, O, F, R>( + args: &[&'a dyn Array], + op: F, + name: &str, +) -> Result> +where + R: AsRef, + O: OffsetSizeTrait, + T: OffsetSizeTrait, + F: Fn(&'a str) -> R, +{ + if args.len() != 1 { + return exec_err!( + "{:?} args were supplied but {} takes exactly one argument", + args.len(), + name + ); + } + + let string_array = as_generic_string_array::(args[0])?; + + // first map is the iterator, second is for the `Option<_>` + Ok(string_array.iter().map(|string| string.map(&op)).collect()) +} + +pub(crate) fn handle<'a, F, R>( + args: &'a [ColumnarValue], + op: F, + name: &str, +) -> Result +where + R: AsRef, + F: Fn(&'a str) -> R, +{ + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i32, + i32, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i64, + i64, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + } +} + +pub(super) fn make_scalar_function( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) + .map(|(arg, hint)| { + // Decide on the length to expand this scalar to depending + // on the given hints. + let expansion_len = match hint { + Hint::AcceptsSingular => 1, + Hint::Pad => inferred_length, + }; + arg.clone().into_array(expansion_len) + }) + .collect::>>()?; + + let result = (inner)(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + }) +} diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs new file mode 100644 index 000000000000..535ffb14f5f5 --- /dev/null +++ b/datafusion/functions/src/string/ltrim.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// ltrim('zzzytest', 'xyz') = 'test' +fn ltrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Left) +} + +#[derive(Debug)] +pub(super) struct LtrimFunc { + signature: Signature, +} + +impl LtrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LtrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ltrim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "ltrim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(ltrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(ltrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function ltrim"), + } + } +} diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 08fcbb363bbc..13c02d5dfac3 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -15,278 +15,63 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}, - datatypes::DataType, -}; -use datafusion_common::{ - cast::as_generic_string_array, exec_err, plan_err, Result, ScalarValue, -}; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; -use datafusion_physical_expr::functions::Hint; -use std::{ - fmt::{Display, Formatter}, - sync::Arc, -}; +//! "string" DataFusion functions -/// Creates a function to identify the optimal return type of a string function given -/// the type of its first argument. -/// -/// If the input type is `LargeUtf8` or `LargeBinary` the return type is -/// `$largeUtf8Type`, -/// -/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, -macro_rules! get_optimal_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - fn $FUNC(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - // LargeBinary inputs are automatically coerced to Utf8 - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - // Binary inputs are automatically coerced to Utf8 - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - DataType::Dictionary(_, value_type) => match **value_type { - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - _ => { - return plan_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - **value_type - ); - } - }, - data_type => { - return plan_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - data_type - ); - } - }) - } - }; -} +use std::sync::Arc; -// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. -get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); +use datafusion_expr::ScalarUDF; -/// applies a unary expression to `args[0]` that is expected to be downcastable to -/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) -/// # Errors -/// This function errors when: -/// * the number of arguments is not 1 -/// * the first argument is not castable to a `GenericStringArray` -pub(crate) fn unary_string_function<'a, T, O, F, R>( - args: &[&'a dyn Array], - op: F, - name: &str, -) -> Result> -where - R: AsRef, - O: OffsetSizeTrait, - T: OffsetSizeTrait, - F: Fn(&'a str) -> R, -{ - if args.len() != 1 { - return exec_err!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name - ); - } +mod btrim; +mod common; +mod ltrim; +mod rtrim; +mod starts_with; +mod to_hex; +mod upper; - let string_array = as_generic_string_array::(args[0])?; +// create UDFs +make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); +make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); +make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); +make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); +make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); +make_udf_function!(upper::UpperFunc, UPPER, upper); - // first map is the iterator, second is for the `Option<_>` - Ok(string_array.iter().map(|string| string.map(&op)).collect()) -} +pub mod expr_fn { + use datafusion_expr::Expr; -fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result -where - R: AsRef, - F: Fn(&'a str) -> R, -{ - match &args[0] { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i32, - i32, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i64, - i64, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, + #[doc = "Removes all characters, spaces by default, from both sides of a string"] + pub fn btrim(args: Vec) -> Expr { + super::btrim().call(args) } -} - -// TODO: mode allow[(dead_code)] after move ltrim and rtrim -enum TrimType { - #[allow(dead_code)] - Left, - #[allow(dead_code)] - Right, - Both, -} -impl Display for TrimType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - TrimType::Left => write!(f, "ltrim"), - TrimType::Right => write!(f, "rtrim"), - TrimType::Both => write!(f, "btrim"), - } + #[doc = "Removes all characters, spaces by default, from the beginning of a string"] + pub fn ltrim(args: Vec) -> Expr { + super::ltrim().call(args) } -} - -fn general_trim( - args: &[ArrayRef], - trim_type: TrimType, -) -> Result { - let func = match trim_type { - TrimType::Left => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_start_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Right => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Both => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>( - str::trim_start_matches::<&[char]>(input, pattern.as_ref()), - pattern.as_ref(), - ) - }, - }; - - let string_array = as_generic_string_array::(&args[0])?; - match args.len() { - 1 => { - let result = string_array - .iter() - .map(|string| string.map(|string: &str| func(string, " "))) - .collect::>(); + #[doc = "Removes all characters, spaces by default, from the end of a string"] + pub fn rtrim(args: Vec) -> Expr { + super::rtrim().call(args) + } - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let characters_array = as_generic_string_array::(&args[1])?; + #[doc = "Returns true if string starts with prefix."] + pub fn starts_with(arg1: Expr, arg2: Expr) -> Expr { + super::starts_with().call(vec![arg1, arg2]) + } - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters)), - _ => None, - }) - .collect::>(); + #[doc = "Converts an integer to a hexadecimal string."] + pub fn to_hex(arg1: Expr) -> Expr { + super::to_hex().call(vec![arg1]) + } - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!( - "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." - ) - } + #[doc = "Converts a string to uppercase."] + pub fn upper(arg1: Expr) -> Expr { + super::upper().call(vec![arg1]) } } -pub(super) fn make_scalar_function( - inner: F, - hints: Vec, -) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, -{ - Arc::new(move |args: &[ColumnarValue]| { - // first, identify if any of the arguments is an Array. If yes, store its `len`, - // as any scalar will need to be converted to an array of len `len`. - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let is_scalar = len.is_none(); - - let inferred_length = len.unwrap_or(1); - let args = args - .iter() - .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) - .map(|(arg, hint)| { - // Decide on the length to expand this scalar to depending - // on the given hints. - let expansion_len = match hint { - Hint::AcceptsSingular => 1, - Hint::Pad => inferred_length, - }; - arg.clone().into_array(expansion_len) - }) - .collect::>>()?; - - let result = (inner)(&args); - if is_scalar { - // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) - } else { - result.map(ColumnarValue::Array) - } - }) +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![btrim(), ltrim(), rtrim(), starts_with(), to_hex(), upper()] } - -mod starts_with; -mod to_hex; -mod trim; -mod upper; -// create UDFs -make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); -make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); -make_udf_function!(trim::TrimFunc, TRIM, trim); -make_udf_function!(upper::UpperFunc, UPPER, upper); - -export_functions!( - ( - starts_with, - arg1 arg2, - "Returns true if string starts with prefix."), - ( - to_hex, - arg1, - "Converts an integer to a hexadecimal string."), - (trim, - arg1, - "removes all characters, space by default from the string"), - (upper, - arg1, - "Converts a string to uppercase.")); diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs new file mode 100644 index 000000000000..17d2f8234b34 --- /dev/null +++ b/datafusion/functions/src/string/rtrim.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// rtrim('testxxzx', 'xyz') = 'test' +fn rtrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Right) +} + +#[derive(Debug)] +pub(super) struct RtrimFunc { + signature: Signature, +} + +impl RtrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RtrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rtrim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "rtrim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(rtrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(rtrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function rtrim"), + } + } +} diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 1fce399d1e70..4450b9d332a0 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::make_scalar_function; use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; @@ -24,8 +25,6 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -use crate::string::make_scalar_function; - /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' pub fn starts_with(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 4dfc84887da2..1bdece3f7af8 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::make_scalar_function; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, @@ -27,8 +28,6 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -use super::make_scalar_function; - /// Converts the number to its equivalent hexadecimal representation. /// to_hex(2147483647) = '7fffffff' pub fn to_hex(args: &[ArrayRef]) -> Result diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index ed41487699aa..a0c910ebb2c8 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -15,16 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::{handle, utf8_to_str_type}; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use crate::string::utf8_to_str_type; - -use super::handle; - #[derive(Debug)] pub(super) struct UpperFunc { signature: Signature, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index f2c93c3ec1dd..a6efe0e0861d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -273,15 +273,6 @@ pub fn create_physical_fun( _ => unreachable!(), }, }), - BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function btrim"), - }), BuiltinScalarFunction::CharacterLength => { Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -347,15 +338,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function lpad"), }), - BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::ltrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::ltrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function ltrim"), - }), BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { @@ -427,15 +409,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function rpad"), }), - BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::rtrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::rtrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function rtrim"), - }), BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::split_part::)(args) @@ -752,70 +725,6 @@ mod tests { Int32Array ); test_function!(BitLength, &[lit("")], Ok(Some(0)), i32, Int32, Int32Array); - test_function!( - Btrim, - &[lit(" trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit(" trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("\n trim \n")], - Ok(Some("\n trim \n")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("xyxtrimyyx"), lit("xyz"),], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("\nxyxtrimyyx\n"), lit("xyz\n"),], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit(ScalarValue::Utf8(None)), lit("xyz"),], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("xyxtrimyyx"), lit(ScalarValue::Utf8(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); #[cfg(feature = "unicode_expressions")] test_function!( CharacterLength, @@ -1287,54 +1196,6 @@ mod tests { Utf8, StringArray ); - test_function!( - Ltrim, - &[lit(" trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit(" trim ")], - Ok(Some("trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("trim ")], - Ok(Some("trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("\n trim ")], - Ok(Some("\n trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); test_function!( OctetLength, &[lit("chars")], @@ -1683,54 +1544,6 @@ mod tests { Utf8, StringArray ); - test_function!( - Rtrim, - &[lit("trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim ")], - Ok(Some(" trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim \n")], - Ok(Some(" trim \n")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim")], - Ok(Some(" trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit("trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); test_function!( SplitPart, &[ diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 86c0092a220d..f5229d92545e 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -21,11 +21,8 @@ //! String expressions +use std::iter; use std::sync::Arc; -use std::{ - fmt::{Display, Formatter}, - iter, -}; use arrow::{ array::{ @@ -346,95 +343,6 @@ pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |string| string.to_lowercase(), "lower") } -enum TrimType { - Left, - Right, - Both, -} - -impl Display for TrimType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - TrimType::Left => write!(f, "ltrim"), - TrimType::Right => write!(f, "rtrim"), - TrimType::Both => write!(f, "btrim"), - } - } -} - -fn general_trim( - args: &[ArrayRef], - trim_type: TrimType, -) -> Result { - let func = match trim_type { - TrimType::Left => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_start_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Right => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Both => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>( - str::trim_start_matches::<&[char]>(input, pattern.as_ref()), - pattern.as_ref(), - ) - }, - }; - - let string_array = as_generic_string_array::(&args[0])?; - - match args.len() { - 1 => { - let result = string_array - .iter() - .map(|string| string.map(|string: &str| func(string, " "))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!( - "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." - ) - } - } -} - -/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. -/// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Both) -} - -/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. -/// ltrim('zzzytest', 'xyz') = 'test' -pub fn ltrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Left) -} - -/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. -/// rtrim('testxxzx', 'xyz') = 'test' -pub fn rtrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Right) -} - /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c009682d5a4d..416b49db7aa7 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -564,7 +564,7 @@ enum ScalarFunction { // 20 was Array // RegexpMatch = 21; BitLength = 22; - Btrim = 23; + // 23 was Btrim CharacterLength = 24; Chr = 25; Concat = 26; @@ -575,7 +575,7 @@ enum ScalarFunction { Left = 31; Lpad = 32; Lower = 33; - Ltrim = 34; + // 34 was Ltrim // 35 was MD5 // 36 was NullIf OctetLength = 37; @@ -586,7 +586,7 @@ enum ScalarFunction { Reverse = 42; Right = 43; Rpad = 44; - Rtrim = 45; + // 45 was Rtrim // 46 was SHA224 // 47 was SHA256 // 48 was SHA384 diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 58683dba6dff..49102137b659 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22930,7 +22930,6 @@ impl serde::Serialize for ScalarFunction { Self::Sqrt => "Sqrt", Self::Trunc => "Trunc", Self::BitLength => "BitLength", - Self::Btrim => "Btrim", Self::CharacterLength => "CharacterLength", Self::Chr => "Chr", Self::Concat => "Concat", @@ -22939,7 +22938,6 @@ impl serde::Serialize for ScalarFunction { Self::Left => "Left", Self::Lpad => "Lpad", Self::Lower => "Lower", - Self::Ltrim => "Ltrim", Self::OctetLength => "OctetLength", Self::Random => "Random", Self::Repeat => "Repeat", @@ -22947,7 +22945,6 @@ impl serde::Serialize for ScalarFunction { Self::Reverse => "Reverse", Self::Right => "Right", Self::Rpad => "Rpad", - Self::Rtrim => "Rtrim", Self::SplitPart => "SplitPart", Self::Strpos => "Strpos", Self::Substr => "Substr", @@ -23004,7 +23001,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sqrt", "Trunc", "BitLength", - "Btrim", "CharacterLength", "Chr", "Concat", @@ -23013,7 +23009,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Left", "Lpad", "Lower", - "Ltrim", "OctetLength", "Random", "Repeat", @@ -23021,7 +23016,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Reverse", "Right", "Rpad", - "Rtrim", "SplitPart", "Strpos", "Substr", @@ -23107,7 +23101,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sqrt" => Ok(ScalarFunction::Sqrt), "Trunc" => Ok(ScalarFunction::Trunc), "BitLength" => Ok(ScalarFunction::BitLength), - "Btrim" => Ok(ScalarFunction::Btrim), "CharacterLength" => Ok(ScalarFunction::CharacterLength), "Chr" => Ok(ScalarFunction::Chr), "Concat" => Ok(ScalarFunction::Concat), @@ -23116,7 +23109,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Left" => Ok(ScalarFunction::Left), "Lpad" => Ok(ScalarFunction::Lpad), "Lower" => Ok(ScalarFunction::Lower), - "Ltrim" => Ok(ScalarFunction::Ltrim), "OctetLength" => Ok(ScalarFunction::OctetLength), "Random" => Ok(ScalarFunction::Random), "Repeat" => Ok(ScalarFunction::Repeat), @@ -23124,7 +23116,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Reverse" => Ok(ScalarFunction::Reverse), "Right" => Ok(ScalarFunction::Right), "Rpad" => Ok(ScalarFunction::Rpad), - "Rtrim" => Ok(ScalarFunction::Rtrim), "SplitPart" => Ok(ScalarFunction::SplitPart), "Strpos" => Ok(ScalarFunction::Strpos), "Substr" => Ok(ScalarFunction::Substr), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8eabb3b18603..5e458bfef016 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2863,7 +2863,7 @@ pub enum ScalarFunction { /// 20 was Array /// RegexpMatch = 21; BitLength = 22, - Btrim = 23, + /// 23 was Btrim CharacterLength = 24, Chr = 25, Concat = 26, @@ -2874,7 +2874,7 @@ pub enum ScalarFunction { Left = 31, Lpad = 32, Lower = 33, - Ltrim = 34, + /// 34 was Ltrim /// 35 was MD5 /// 36 was NullIf OctetLength = 37, @@ -2885,7 +2885,7 @@ pub enum ScalarFunction { Reverse = 42, Right = 43, Rpad = 44, - Rtrim = 45, + /// 45 was Rtrim /// 46 was SHA224 /// 47 was SHA256 /// 48 was SHA384 @@ -3003,7 +3003,6 @@ impl ScalarFunction { ScalarFunction::Sqrt => "Sqrt", ScalarFunction::Trunc => "Trunc", ScalarFunction::BitLength => "BitLength", - ScalarFunction::Btrim => "Btrim", ScalarFunction::CharacterLength => "CharacterLength", ScalarFunction::Chr => "Chr", ScalarFunction::Concat => "Concat", @@ -3012,7 +3011,6 @@ impl ScalarFunction { ScalarFunction::Left => "Left", ScalarFunction::Lpad => "Lpad", ScalarFunction::Lower => "Lower", - ScalarFunction::Ltrim => "Ltrim", ScalarFunction::OctetLength => "OctetLength", ScalarFunction::Random => "Random", ScalarFunction::Repeat => "Repeat", @@ -3020,7 +3018,6 @@ impl ScalarFunction { ScalarFunction::Reverse => "Reverse", ScalarFunction::Right => "Right", ScalarFunction::Rpad => "Rpad", - ScalarFunction::Rtrim => "Rtrim", ScalarFunction::SplitPart => "SplitPart", ScalarFunction::Strpos => "Strpos", ScalarFunction::Substr => "Substr", @@ -3071,7 +3068,6 @@ impl ScalarFunction { "Sqrt" => Some(Self::Sqrt), "Trunc" => Some(Self::Trunc), "BitLength" => Some(Self::BitLength), - "Btrim" => Some(Self::Btrim), "CharacterLength" => Some(Self::CharacterLength), "Chr" => Some(Self::Chr), "Concat" => Some(Self::Concat), @@ -3080,7 +3076,6 @@ impl ScalarFunction { "Left" => Some(Self::Left), "Lpad" => Some(Self::Lpad), "Lower" => Some(Self::Lower), - "Ltrim" => Some(Self::Ltrim), "OctetLength" => Some(Self::OctetLength), "Random" => Some(Self::Random), "Repeat" => Some(Self::Repeat), @@ -3088,7 +3083,6 @@ impl ScalarFunction { "Reverse" => Some(Self::Reverse), "Right" => Some(Self::Right), "Rpad" => Some(Self::Rpad), - "Rtrim" => Some(Self::Rtrim), "SplitPart" => Some(Self::SplitPart), "Strpos" => Some(Self::Strpos), "Substr" => Some(Self::Substr), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 64ceb37d2961..d41add915a96 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,17 +48,16 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, ascii, asinh, atan, atan2, atanh, bit_length, btrim, cbrt, ceil, - character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, - degrees, ends_with, exp, + acosh, ascii, asinh, atan, atan2, atanh, bit_length, cbrt, ceil, character_length, + chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, nanvl, octet_length, overlay, pi, power, radians, random, repeat, - replace, reverse, right, round, rpad, rtrim, signum, sin, sinh, split_part, sqrt, - strpos, substr, substr_index, substring, translate, trunc, uuid, AggregateFunction, - Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, + lower, lpad, nanvl, octet_length, overlay, pi, power, radians, random, repeat, + replace, reverse, right, round, rpad, signum, sin, sinh, split_part, sqrt, strpos, + substr, substr_index, substring, translate, trunc, uuid, AggregateFunction, Between, + BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, @@ -461,13 +460,10 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::OctetLength => Self::OctetLength, ScalarFunction::Concat => Self::Concat, ScalarFunction::Lower => Self::Lower, - ScalarFunction::Ltrim => Self::Ltrim, - ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, ScalarFunction::Ascii => Self::Ascii, ScalarFunction::BitLength => Self::BitLength, - ScalarFunction::Btrim => Self::Btrim, ScalarFunction::CharacterLength => Self::CharacterLength, ScalarFunction::Chr => Self::Chr, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, @@ -1439,12 +1435,6 @@ pub fn parse_expr( ScalarFunction::Lower => { Ok(lower(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Ltrim => { - Ok(ltrim(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Rtrim => { - Ok(rtrim(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Ascii => { Ok(ascii(parse_expr(&args[0], registry, codec)?)) } @@ -1512,12 +1502,6 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), - ScalarFunction::Btrim => Ok(btrim( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), ScalarFunction::SplitPart => Ok(split_part( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 89bd93550a04..39d663b6c59b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1481,13 +1481,10 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::OctetLength => Self::OctetLength, BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::Lower => Self::Lower, - BuiltinScalarFunction::Ltrim => Self::Ltrim, - BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, BuiltinScalarFunction::Ascii => Self::Ascii, BuiltinScalarFunction::BitLength => Self::BitLength, - BuiltinScalarFunction::Btrim => Self::Btrim, BuiltinScalarFunction::CharacterLength => Self::CharacterLength, BuiltinScalarFunction::Chr => Self::Chr, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index c34b42193cec..04f8001bfc1b 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,20 +15,11 @@ // specific language governing permissions and limitations // under the License. -mod binary_op; -mod function; -mod grouping_set; -mod identifier; -mod json_access; -mod order_by; -mod subquery; -mod substring; -mod unary_op; -mod value; - -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; use arrow_schema::TimeUnit; +use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; +use sqlparser::parser::ParserError::ParserError; + use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, DFSchema, Result, ScalarValue, @@ -40,8 +31,19 @@ use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, }; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; -use sqlparser::parser::ParserError::ParserError; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + +mod binary_op; +mod function; +mod grouping_set; +mod identifier; +mod json_access; +mod order_by; +mod subquery; +mod substring; +mod unary_op; +mod value; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn sql_expr_to_logical_expr( @@ -743,13 +745,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = match trim_where { - Some(TrimWhereField::Leading) => BuiltinScalarFunction::Ltrim, - Some(TrimWhereField::Trailing) => BuiltinScalarFunction::Rtrim, - Some(TrimWhereField::Both) => BuiltinScalarFunction::Btrim, - None => BuiltinScalarFunction::Btrim, - }; - let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; let args = match (trim_what, trim_characters) { (Some(to_trim), None) => { @@ -774,7 +769,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } (None, None) => Ok(vec![arg]), }?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + + let fun_name = match trim_where { + Some(TrimWhereField::Leading) => "ltrim", + Some(TrimWhereField::Trailing) => "rtrim", + Some(TrimWhereField::Both) => "btrim", + None => "trim", + }; + let fun = self + .context_provider + .get_function_meta(fun_name) + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected '{fun_name}' function") + })?; + + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } fn sql_overlay_to_expr( diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d4570dbc35f2..5eb3436b4256 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -731,12 +731,15 @@ btrim(str[, trim_str]) Can be a constant, column, or function, and any combination of string operators. - **trim_str**: String expression to trim from the beginning and end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters_. + _Default is whitespace characters._ **Related functions**: [ltrim](#ltrim), -[rtrim](#rtrim), -[trim](#trim) +[rtrim](#rtrim) + +#### Aliases + +- trim ### `char_length` @@ -919,26 +922,25 @@ lpad(str, n[, padding_str]) ### `ltrim` -Removes leading spaces from a string. +Trims the specified trim string from the beginning of a string. +If no trim string is provided, all whitespace is removed from the start +of the input string. ``` -ltrim(str) +ltrim(str[, trim_str]) ``` #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of string operators. +- **trim_str**: String expression to trim from the beginning of the input string. + Can be a constant, column, or function, and any combination of arithmetic operators. + _Default is whitespace characters._ **Related functions**: [btrim](#btrim), -[rtrim](#rtrim), -[trim](#trim) - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +[rtrim](#rtrim) ### `octet_length` @@ -1040,21 +1042,25 @@ rpad(str, n[, padding_str]) ### `rtrim` -Removes trailing spaces from a string. +Trims the specified trim string from the end of a string. +If no trim string is provided, all whitespace is removed from the end +of the input string. ``` -rtrim(str) +rtrim(str[, trim_str]) ``` #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of string operators. +- **trim_str**: String expression to trim from the end of the input string. + Can be a constant, column, or function, and any combination of arithmetic operators. + _Default is whitespace characters._ **Related functions**: [btrim](#btrim), -[ltrim](#ltrim), -[trim](#trim) +[ltrim](#ltrim) ### `split_part` @@ -1154,21 +1160,7 @@ to_hex(int) ### `trim` -Removes leading and trailing spaces from a string. - -``` -trim(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[btrim](#btrim), -[ltrim](#ltrim), -[rtrim](#rtrim) +_Alias of [btrim](#btrim)._ ### `upper`