diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 2e734e4acc78..46df18b08970 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -46,6 +46,7 @@ num = "0.4" half = "1.8" csv_crate = { version = "1.1", optional = true, package="csv" } regex = "1.3" +regex-syntax = { version = "0.6.27", default-features = false, features = ["unicode"] } lazy_static = "1.4" packed_simd = { version = "0.3", optional = true, package = "packed_simd_2" } chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index b25676c6fc4b..5d696862f781 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -36,7 +36,7 @@ use crate::datatypes::{ }; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -use regex::{escape, Regex}; +use regex::Regex; use std::any::type_name; use std::collections::HashMap; @@ -267,7 +267,7 @@ where let re = if let Some(ref regex) = map.get(pat) { regex } else { - let re_pattern = escape(pat).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(pat)?; let re = op(&re_pattern)?; map.insert(pat, re); map.get(pat).unwrap() @@ -346,7 +346,9 @@ pub fn like_utf8_scalar( bit_util::set_bit(bool_slice, i); } } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use starts_with let starts_with = &right[..right.len() - 1]; @@ -364,7 +366,7 @@ pub fn like_utf8_scalar( } } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", @@ -394,6 +396,43 @@ pub fn like_utf8_scalar( Ok(BooleanArray::from(data)) } +/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does: +/// +/// 1. Replace like wildcards for regex expressions as the pattern will be evaluated using regex match: `%` => `.*` and `_` => `.` +/// 2. Escape regex meta characters to match them and not be evaluated as regex special chars. For example: `.` => `\\.` +/// 3. Replace escaped like wildcards removing the escape characters to be able to match it as a regex. For example: `\\%` => `%` +fn replace_like_wildcards(pattern: &str) -> Result { + let mut result = String::new(); + let pattern = String::from(pattern); + let mut chars_iter = pattern.chars().peekable(); + while let Some(c) = chars_iter.next() { + if c == '\\' { + let next = chars_iter.peek(); + match next { + Some(next) if is_like_pattern(*next) => { + result.push(*next); + // Skipping the next char as it is already appended + chars_iter.next(); + } + _ => { + result.push('\\'); + result.push('\\'); + } + } + } else if regex_syntax::is_meta_character(c) { + result.push('\\'); + result.push(c); + } else if c == '%' { + result.push_str(".*"); + } else if c == '_' { + result.push('.'); + } else { + result.push(c); + } + } + Ok(result) +} + /// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / /// [`LargeStringArray`]. /// @@ -428,7 +467,9 @@ pub fn nlike_utf8_scalar( for i in 0..left.len() { result.append(left.value(i) != right); } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use ends_with for i in 0..left.len() { @@ -440,7 +481,7 @@ pub fn nlike_utf8_scalar( result.append(!left.value(i).ends_with(&right[1..])); } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", @@ -501,7 +542,9 @@ pub fn ilike_utf8_scalar( for i in 0..left.len() { result.append(left.value(i) == right); } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use ends_with for i in 0..left.len() { @@ -521,7 +564,7 @@ pub fn ilike_utf8_scalar( ); } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from ILIKE pattern: {}", @@ -582,7 +625,9 @@ pub fn nilike_utf8_scalar( for i in 0..left.len() { result.append(left.value(i) != right); } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use ends_with for i in 0..left.len() { @@ -604,7 +649,7 @@ pub fn nilike_utf8_scalar( ); } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from ILIKE pattern: {}", @@ -3927,6 +3972,50 @@ mod tests { vec![false, true, false, false] ); + test_utf8_scalar!( + test_utf8_scalar_like_escape, + vec!["a%", "a\\x"], + "a\\%", + like_utf8_scalar, + vec![true, false] + ); + + test_utf8!( + test_utf8_scalar_ilike_regex, + vec!["%%%"], + vec![r#"\%_\%"#], + ilike_utf8, + vec![true] + ); + + #[test] + fn test_replace_like_wildcards() { + let a_eq = "_%"; + let expected = "..*"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_leave_like_meta_chars() { + let a_eq = "\\%\\_"; + let expected = "%_"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_with_multiple_escape_chars() { + let a_eq = "\\\\%"; + let expected = "\\\\%"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_escape_regex_meta_char() { + let a_eq = "."; + let expected = "\\."; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + test_utf8!( test_utf8_array_eq, vec!["arrow", "arrow", "arrow", "arrow"],