Skip to content

Commit

Permalink
Move trim functions (btrim, ltrim, rtrim) to datafusion_functions, make expr_fn API consistent (#9730)
Browse files Browse the repository at this point in the history

* Fix to_timestamp benchmark

* Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started.

* Move trim functions to datafusion-functions

* Doc updates for ltrim, rtrim and trim to reflect how they actually function.

* Fixed struct name Trim -> BTrim
  • Loading branch information
Omega359 authored Mar 22, 2024
1 parent 47f4b5a commit d321ba3
Show file tree
Hide file tree
Showing 21 changed files with 559 additions and 714 deletions.
4 changes: 2 additions & 2 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ async fn test_fn_lpad_with_string() -> Result<()> {

#[tokio::test]
async fn test_fn_ltrim() -> Result<()> {
let expr = ltrim(lit(" a b c "));
let expr = ltrim(vec![lit(" a b c ")]);

let expected = [
"+-----------------------------------------+",
Expand All @@ -384,7 +384,7 @@ async fn test_fn_ltrim() -> Result<()> {

#[tokio::test]
async fn test_fn_ltrim_with_columns() -> Result<()> {
let expr = ltrim(col("a"));
let expr = ltrim(vec![col("a")]);

let expected = [
"+---------------+",
Expand Down
27 changes: 0 additions & 27 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ pub enum BuiltinScalarFunction {
Ascii,
/// bit_length
BitLength,
/// btrim
Btrim,
/// character_length
CharacterLength,
/// chr
Expand All @@ -127,8 +125,6 @@ pub enum BuiltinScalarFunction {
Lpad,
/// lower
Lower,
/// ltrim
Ltrim,
/// octet_length
OctetLength,
/// random
Expand All @@ -143,8 +139,6 @@ pub enum BuiltinScalarFunction {
Right,
/// rpad
Rpad,
/// rtrim
Rtrim,
/// split_part
SplitPart,
/// strpos
Expand Down Expand Up @@ -248,7 +242,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
BuiltinScalarFunction::BitLength => Volatility::Immutable,
BuiltinScalarFunction::Btrim => Volatility::Immutable,
BuiltinScalarFunction::CharacterLength => Volatility::Immutable,
BuiltinScalarFunction::Chr => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
Expand All @@ -258,15 +251,13 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Left => Volatility::Immutable,
BuiltinScalarFunction::Lpad => Volatility::Immutable,
BuiltinScalarFunction::Lower => Volatility::Immutable,
BuiltinScalarFunction::Ltrim => Volatility::Immutable,
BuiltinScalarFunction::OctetLength => Volatility::Immutable,
BuiltinScalarFunction::Radians => Volatility::Immutable,
BuiltinScalarFunction::Repeat => Volatility::Immutable,
BuiltinScalarFunction::Replace => Volatility::Immutable,
BuiltinScalarFunction::Reverse => Volatility::Immutable,
BuiltinScalarFunction::Right => Volatility::Immutable,
BuiltinScalarFunction::Rpad => Volatility::Immutable,
BuiltinScalarFunction::Rtrim => Volatility::Immutable,
BuiltinScalarFunction::SplitPart => Volatility::Immutable,
BuiltinScalarFunction::Strpos => Volatility::Immutable,
BuiltinScalarFunction::Substr => Volatility::Immutable,
Expand Down Expand Up @@ -303,9 +294,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::BitLength => {
utf8_to_int_type(&input_expr_types[0], "bit_length")
}
BuiltinScalarFunction::Btrim => {
utf8_to_str_type(&input_expr_types[0], "btrim")
}
BuiltinScalarFunction::CharacterLength => {
utf8_to_int_type(&input_expr_types[0], "character_length")
}
Expand All @@ -325,9 +313,6 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "lower")
}
BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
BuiltinScalarFunction::Ltrim => {
utf8_to_str_type(&input_expr_types[0], "ltrim")
}
BuiltinScalarFunction::OctetLength => {
utf8_to_int_type(&input_expr_types[0], "octet_length")
}
Expand All @@ -347,9 +332,6 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "right")
}
BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"),
BuiltinScalarFunction::Rtrim => {
utf8_to_str_type(&input_expr_types[0], "rtrim")
}
BuiltinScalarFunction::SplitPart => {
utf8_to_str_type(&input_expr_types[0], "split_part")
}
Expand Down Expand Up @@ -456,12 +438,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Reverse => {
Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility())
}
BuiltinScalarFunction::Btrim
| BuiltinScalarFunction::Ltrim
| BuiltinScalarFunction::Rtrim => Signature::one_of(
vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])],
self.volatility(),
),
BuiltinScalarFunction::Chr => {
Signature::uniform(1, vec![Int64], self.volatility())
}
Expand Down Expand Up @@ -703,7 +679,6 @@ impl BuiltinScalarFunction {
// string functions
BuiltinScalarFunction::Ascii => &["ascii"],
BuiltinScalarFunction::BitLength => &["bit_length"],
BuiltinScalarFunction::Btrim => &["btrim"],
BuiltinScalarFunction::CharacterLength => {
&["character_length", "char_length", "length"]
}
Expand All @@ -715,14 +690,12 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lower => &["lower"],
BuiltinScalarFunction::Lpad => &["lpad"],
BuiltinScalarFunction::Ltrim => &["ltrim"],
BuiltinScalarFunction::OctetLength => &["octet_length"],
BuiltinScalarFunction::Repeat => &["repeat"],
BuiltinScalarFunction::Replace => &["replace"],
BuiltinScalarFunction::Reverse => &["reverse"],
BuiltinScalarFunction::Right => &["right"],
BuiltinScalarFunction::Rpad => &["rpad"],
BuiltinScalarFunction::Rtrim => &["rtrim"],
BuiltinScalarFunction::SplitPart => &["split_part"],
BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"],
BuiltinScalarFunction::Substr => &["substr"],
Expand Down
21 changes: 0 additions & 21 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,12 +601,6 @@ scalar_expr!(
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`");
scalar_expr!(Lower, lower, string, "convert the string to lower case");
scalar_expr!(
Ltrim,
ltrim,
string,
"removes all characters, spaces by default, from the beginning of a string"
);
scalar_expr!(
OctetLength,
octet_length,
Expand All @@ -617,12 +611,6 @@ scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `fro
scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times");
scalar_expr!(Reverse, reverse, string, "reverses the `string`");
scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`");
scalar_expr!(
Rtrim,
rtrim,
string,
"removes all characters, spaces by default, from the end of a string"
);
scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index.");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`");
Expand All @@ -640,11 +628,6 @@ nary_scalar_expr!(
rpad,
"fill up a string to the length by appending the characters"
);
nary_scalar_expr!(
Btrim,
btrim,
"removes all characters, spaces by default, from both sides of a string"
);
nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL");
//there is a func concat_ws before, so use concat_ws_expr as name.c
nary_scalar_expr!(
Expand Down Expand Up @@ -1082,8 +1065,6 @@ mod test {

test_scalar_expr!(Ascii, ascii, input);
test_scalar_expr!(BitLength, bit_length, string);
test_nary_scalar_expr!(Btrim, btrim, string);
test_nary_scalar_expr!(Btrim, btrim, string, characters);
test_scalar_expr!(CharacterLength, character_length, string);
test_scalar_expr!(Chr, chr, string);
test_scalar_expr!(Gcd, gcd, arg_1, arg_2);
Expand All @@ -1093,15 +1074,13 @@ mod test {
test_scalar_expr!(Lower, lower, string);
test_nary_scalar_expr!(Lpad, lpad, string, count);
test_nary_scalar_expr!(Lpad, lpad, string, count, characters);
test_scalar_expr!(Ltrim, ltrim, string);
test_scalar_expr!(OctetLength, octet_length, string);
test_scalar_expr!(Replace, replace, string, from, to);
test_scalar_expr!(Repeat, repeat, string, count);
test_scalar_expr!(Reverse, reverse, string);
test_scalar_expr!(Right, right, string, count);
test_nary_scalar_expr!(Rpad, rpad, string, count);
test_nary_scalar_expr!(Rpad, rpad, string, count, characters);
test_scalar_expr!(Rtrim, rtrim, string);
test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
test_scalar_expr!(EndsWith, ends_with, string, characters);
test_scalar_expr!(Strpos, strpos, string, substring);
Expand Down
7 changes: 4 additions & 3 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ authors = { workspace = true }
rust-version = { workspace = true }

[features]
# enable string functions
string_expressions = []
# enable core functions
core_expressions = []
crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
# enable datetime functions
datetime_expressions = []
# Enable encoding by default so the doctests work. In general don't automatically enable all packages.
Expand All @@ -51,7 +50,9 @@ encoding_expressions = ["base64", "hex"]
math_expressions = []
# enable regular expressions
regex_expressions = ["regex"]
crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
# enable string functions
string_expressions = []

[lib]
name = "datafusion_functions"
path = "src/lib.rs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,63 +16,68 @@
// under the License.

use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion_common::exec_err;
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;

use crate::string::{make_scalar_function, utf8_to_str_type};
use arrow::datatypes::DataType;

use super::{general_trim, TrimType};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed.
use crate::string::common::*;

/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed.
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
pub fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
general_trim::<T>(args, TrimType::Both)
}

#[derive(Debug)]
pub(super) struct TrimFunc {
pub(super) struct BTrimFunc {
signature: Signature,
aliases: Vec<String>,
}

impl TrimFunc {
impl BTrimFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])],
Volatility::Immutable,
),
aliases: vec![String::from("trim")],
}
}
}

impl ScalarUDFImpl for TrimFunc {
impl ScalarUDFImpl for BTrimFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"trim"
"btrim"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
utf8_to_str_type(&arg_types[0], "trim")
utf8_to_str_type(&arg_types[0], "btrim")
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(btrim::<i32>, vec![])(args),
DataType::LargeUtf8 => make_scalar_function(btrim::<i64>, vec![])(args),
other => exec_err!("Unsupported data type {other:?} for function trim"),
other => exec_err!("Unsupported data type {other:?} for function btrim"),
}
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}
Loading

0 comments on commit d321ba3

Please sign in to comment.