diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index cf10b18ae3383..76c65716a88f8 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -18,17 +18,15 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, -}; +use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use crate::string::common::StringArrayType; +use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use crate::utils::{make_scalar_function, utf8_to_int_type}; - #[derive(Debug)] pub struct StrposFunc { signature: Signature, @@ -140,24 +138,43 @@ fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>( substring_array: V2, ) -> Result where - V1: ArrayAccessor, - V2: ArrayAccessor, + V1: StringArrayType<'a, Item = &'a str>, + V2: StringArrayType<'a, Item = &'a str>, { - let string_iter = ArrayIter::new(string_array); - let substring_iter = ArrayIter::new(substring_array); + let ascii_only = string_array.is_ascii() && substring_array.is_ascii(); + let string_iter = string_array.iter(); + let substring_iter = substring_array.iter(); let result = string_iter .zip(substring_iter) .map(|(string, substring)| match (string, substring) { (Some(string), Some(substring)) => { - // The `find` method returns the byte index of the substring. - // We count the number of chars up to that byte index. - T::Native::from_usize( - string - .find(substring) - .map(|x| string[..x].chars().count() + 1) - .unwrap_or(0), - ) + // If only ASCII characters are present, we can use the slide window method to find + // the sub vector in the main vector. This is faster than string.find() method. + if ascii_only { + // If the substring is empty, the result is 1. + if substring.as_bytes().is_empty() { + return T::Native::from_usize(1); + } else { + T::Native::from_usize( + string + .as_bytes() + .windows(substring.as_bytes().len()) + .position(|w| w == substring.as_bytes()) + .map(|x| x + 1) + .unwrap_or(0), + ) + } + } else { + // The `find` method returns the byte index of the substring. + // We count the number of chars up to that byte index. + T::Native::from_usize( + string + .find(substring) + .map(|x| string[..x].chars().count() + 1) + .unwrap_or(0), + ) + } } _ => None, }) @@ -201,6 +218,8 @@ mod tests { test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array); test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array); test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array); // LargeUtf8 and LargeUtf8 combinations test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); @@ -208,6 +227,8 @@ mod tests { test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); // Utf8 and LargeUtf8 combinations test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array); @@ -215,6 +236,8 @@ mod tests { test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array); // LargeUtf8 and Utf8 combinations test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array); @@ -222,6 +245,8 @@ mod tests { test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array); // Utf8View and Utf8View combinations test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array); @@ -229,6 +254,8 @@ mod tests { test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array); test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array); test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array); // Utf8View and Utf8 combinations test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array); @@ -236,6 +263,8 @@ mod tests { test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array); test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array); test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array); // Utf8View and LargeUtf8 combinations test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array); @@ -243,5 +272,7 @@ mod tests { test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array); } }