Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move trim functions (btrim, ltrim, rtrim) to datafusion_functions, make expr_fn API consistent #9730

Merged
merged 11 commits into from
Mar 22, 2024
Merged
4 changes: 2 additions & 2 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ async fn test_fn_lpad_with_string() -> Result<()> {

#[tokio::test]
async fn test_fn_ltrim() -> Result<()> {
let expr = ltrim(lit(" a b c "));
let expr = ltrim(vec![lit(" a b c ")]);

let expected = [
"+-----------------------------------------+",
Expand All @@ -384,7 +384,7 @@ async fn test_fn_ltrim() -> Result<()> {

#[tokio::test]
async fn test_fn_ltrim_with_columns() -> Result<()> {
let expr = ltrim(col("a"));
let expr = ltrim(vec![col("a")]);

let expected = [
"+---------------+",
Expand Down
27 changes: 0 additions & 27 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ pub enum BuiltinScalarFunction {
Ascii,
/// bit_length
BitLength,
/// btrim
Btrim,
/// character_length
CharacterLength,
/// chr
Expand All @@ -127,8 +125,6 @@ pub enum BuiltinScalarFunction {
Lpad,
/// lower
Lower,
/// ltrim
Ltrim,
/// octet_length
OctetLength,
/// random
Expand All @@ -143,8 +139,6 @@ pub enum BuiltinScalarFunction {
Right,
/// rpad
Rpad,
/// rtrim
Rtrim,
/// split_part
SplitPart,
/// strpos
Expand Down Expand Up @@ -248,7 +242,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
BuiltinScalarFunction::BitLength => Volatility::Immutable,
BuiltinScalarFunction::Btrim => Volatility::Immutable,
BuiltinScalarFunction::CharacterLength => Volatility::Immutable,
BuiltinScalarFunction::Chr => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
Expand All @@ -258,15 +251,13 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Left => Volatility::Immutable,
BuiltinScalarFunction::Lpad => Volatility::Immutable,
BuiltinScalarFunction::Lower => Volatility::Immutable,
BuiltinScalarFunction::Ltrim => Volatility::Immutable,
BuiltinScalarFunction::OctetLength => Volatility::Immutable,
BuiltinScalarFunction::Radians => Volatility::Immutable,
BuiltinScalarFunction::Repeat => Volatility::Immutable,
BuiltinScalarFunction::Replace => Volatility::Immutable,
BuiltinScalarFunction::Reverse => Volatility::Immutable,
BuiltinScalarFunction::Right => Volatility::Immutable,
BuiltinScalarFunction::Rpad => Volatility::Immutable,
BuiltinScalarFunction::Rtrim => Volatility::Immutable,
BuiltinScalarFunction::SplitPart => Volatility::Immutable,
BuiltinScalarFunction::Strpos => Volatility::Immutable,
BuiltinScalarFunction::Substr => Volatility::Immutable,
Expand Down Expand Up @@ -303,9 +294,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::BitLength => {
utf8_to_int_type(&input_expr_types[0], "bit_length")
}
BuiltinScalarFunction::Btrim => {
utf8_to_str_type(&input_expr_types[0], "btrim")
}
BuiltinScalarFunction::CharacterLength => {
utf8_to_int_type(&input_expr_types[0], "character_length")
}
Expand All @@ -325,9 +313,6 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "lower")
}
BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
BuiltinScalarFunction::Ltrim => {
utf8_to_str_type(&input_expr_types[0], "ltrim")
}
BuiltinScalarFunction::OctetLength => {
utf8_to_int_type(&input_expr_types[0], "octet_length")
}
Expand All @@ -347,9 +332,6 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "right")
}
BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"),
BuiltinScalarFunction::Rtrim => {
utf8_to_str_type(&input_expr_types[0], "rtrim")
}
BuiltinScalarFunction::SplitPart => {
utf8_to_str_type(&input_expr_types[0], "split_part")
}
Expand Down Expand Up @@ -456,12 +438,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Reverse => {
Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility())
}
BuiltinScalarFunction::Btrim
| BuiltinScalarFunction::Ltrim
| BuiltinScalarFunction::Rtrim => Signature::one_of(
vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])],
self.volatility(),
),
BuiltinScalarFunction::Chr => {
Signature::uniform(1, vec![Int64], self.volatility())
}
Expand Down Expand Up @@ -703,7 +679,6 @@ impl BuiltinScalarFunction {
// string functions
BuiltinScalarFunction::Ascii => &["ascii"],
BuiltinScalarFunction::BitLength => &["bit_length"],
BuiltinScalarFunction::Btrim => &["btrim"],
BuiltinScalarFunction::CharacterLength => {
&["character_length", "char_length", "length"]
}
Expand All @@ -715,14 +690,12 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lower => &["lower"],
BuiltinScalarFunction::Lpad => &["lpad"],
BuiltinScalarFunction::Ltrim => &["ltrim"],
BuiltinScalarFunction::OctetLength => &["octet_length"],
BuiltinScalarFunction::Repeat => &["repeat"],
BuiltinScalarFunction::Replace => &["replace"],
BuiltinScalarFunction::Reverse => &["reverse"],
BuiltinScalarFunction::Right => &["right"],
BuiltinScalarFunction::Rpad => &["rpad"],
BuiltinScalarFunction::Rtrim => &["rtrim"],
BuiltinScalarFunction::SplitPart => &["split_part"],
BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"],
BuiltinScalarFunction::Substr => &["substr"],
Expand Down
21 changes: 0 additions & 21 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,12 +601,6 @@ scalar_expr!(
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`");
scalar_expr!(Lower, lower, string, "convert the string to lower case");
scalar_expr!(
Ltrim,
ltrim,
string,
"removes all characters, spaces by default, from the beginning of a string"
);
scalar_expr!(
OctetLength,
octet_length,
Expand All @@ -617,12 +611,6 @@ scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `fro
scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times");
scalar_expr!(Reverse, reverse, string, "reverses the `string`");
scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`");
scalar_expr!(
Rtrim,
rtrim,
string,
"removes all characters, spaces by default, from the end of a string"
);
scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index.");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`");
Expand All @@ -640,11 +628,6 @@ nary_scalar_expr!(
rpad,
"fill up a string to the length by appending the characters"
);
nary_scalar_expr!(
Btrim,
btrim,
"removes all characters, spaces by default, from both sides of a string"
);
nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL");
//there is a func concat_ws before, so use concat_ws_expr as name.c
nary_scalar_expr!(
Expand Down Expand Up @@ -1082,8 +1065,6 @@ mod test {

test_scalar_expr!(Ascii, ascii, input);
test_scalar_expr!(BitLength, bit_length, string);
test_nary_scalar_expr!(Btrim, btrim, string);
test_nary_scalar_expr!(Btrim, btrim, string, characters);
test_scalar_expr!(CharacterLength, character_length, string);
test_scalar_expr!(Chr, chr, string);
test_scalar_expr!(Gcd, gcd, arg_1, arg_2);
Expand All @@ -1093,15 +1074,13 @@ mod test {
test_scalar_expr!(Lower, lower, string);
test_nary_scalar_expr!(Lpad, lpad, string, count);
test_nary_scalar_expr!(Lpad, lpad, string, count, characters);
test_scalar_expr!(Ltrim, ltrim, string);
test_scalar_expr!(OctetLength, octet_length, string);
test_scalar_expr!(Replace, replace, string, from, to);
test_scalar_expr!(Repeat, repeat, string, count);
test_scalar_expr!(Reverse, reverse, string);
test_scalar_expr!(Right, right, string, count);
test_nary_scalar_expr!(Rpad, rpad, string, count);
test_nary_scalar_expr!(Rpad, rpad, string, count, characters);
test_scalar_expr!(Rtrim, rtrim, string);
test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
test_scalar_expr!(EndsWith, ends_with, string, characters);
test_scalar_expr!(Strpos, strpos, string, substring);
Expand Down
7 changes: 4 additions & 3 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ authors = { workspace = true }
rust-version = { workspace = true }

[features]
# enable string functions
string_expressions = []
# enable core functions
core_expressions = []
crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
# enable datetime functions
datetime_expressions = []
# Enable encoding by default so the doctests work. In general don't automatically enable all packages.
Expand All @@ -51,7 +50,9 @@ encoding_expressions = ["base64", "hex"]
math_expressions = []
# enable regular expressions
regex_expressions = ["regex"]
crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
# enable string functions
string_expressions = []

[lib]
name = "datafusion_functions"
path = "src/lib.rs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,27 @@
// under the License.

use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion_common::exec_err;
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;

use crate::string::{make_scalar_function, utf8_to_str_type};
use arrow::datatypes::DataType;

use super::{general_trim, TrimType};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed.
use crate::string::common::*;

/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed.
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
general_trim::<T>(args, TrimType::Both)
}

#[derive(Debug)]
pub(super) struct TrimFunc {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't a huge deal, but I found it a little confusing that the function is called TrimFunc but has a name of "btrim" and an alias of "trim"

I would expect either BTrimFunc or else the name of "trim" and an alias of "btrim"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Urgh, this was an oversight on my part when moving the Trim function here. I'll fix.

signature: Signature,
aliases: Vec<String>,
}

impl TrimFunc {
Expand All @@ -47,6 +47,7 @@ impl TrimFunc {
vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])],
Volatility::Immutable,
),
aliases: vec![String::from("trim")],
}
}
}
Expand All @@ -57,22 +58,26 @@ impl ScalarUDFImpl for TrimFunc {
}

fn name(&self) -> &str {
"trim"
"btrim"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
utf8_to_str_type(&arg_types[0], "trim")
utf8_to_str_type(&arg_types[0], "btrim")
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(btrim::<i32>, vec![])(args),
DataType::LargeUtf8 => make_scalar_function(btrim::<i64>, vec![])(args),
other => exec_err!("Unsupported data type {other:?} for function trim"),
other => exec_err!("Unsupported data type {other:?} for function btrim"),
}
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}
Loading
Loading