From 7dedd02830152c462e9384f9d5bb7dda4987b9d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 20 Sep 2021 09:30:47 +0200 Subject: [PATCH] Improved performance of utf8 validation of large strings via `simdutf8` (-40%) (#426) --- Cargo.toml | 5 ++ benches/read_parquet.rs | 5 ++ parquet_integration/write_parquet.py | 3 + src/array/ord.rs | 71 ++++++++++++++++++------ src/array/specification.rs | 3 +- src/compute/cast/mod.rs | 4 +- src/compute/like.rs | 4 +- src/error.rs | 6 ++ src/ffi/schema.rs | 2 +- src/io/avro/read/deserialize.rs | 2 +- src/io/avro/read/util.rs | 2 +- src/io/csv/read/deserialize.rs | 14 ++--- src/io/parquet/read/schema/metadata.rs | 2 +- src/io/parquet/read/statistics/binary.rs | 4 +- 14 files changed, 93 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0c4d568fde2..b5e0175aa8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,8 +64,13 @@ avro-rs = { version = "0.13", optional = true, default_features = false } # for division/remainder optimization at runtime strength_reduce = { version = "0.2", optional = true } + +# For instruction multiversioning multiversion = { version = "0.6.1", optional = true } +# For SIMD utf8 validation +simdutf8 = "0.1.3" + [dev-dependencies] rand = "0.8" criterion = "0.3" diff --git a/benches/read_parquet.rs b/benches/read_parquet.rs index f455c66434b..55ce25fc96d 100644 --- a/benches/read_parquet.rs +++ b/benches/read_parquet.rs @@ -42,6 +42,11 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| read_decompressed_pages(&buffer, size * 8, 2).unwrap()) }); + let a = format!("read utf8 large 2^{}", i); + c.bench_function(&a, |b| { + b.iter(|| read_decompressed_pages(&buffer, size * 8, 6).unwrap()) + }); + let a = format!("read bool 2^{}", i); c.bench_function(&a, |b| { b.iter(|| read_decompressed_pages(&buffer, size * 8, 3).unwrap()) diff --git a/parquet_integration/write_parquet.py b/parquet_integration/write_parquet.py index 1cda4465af2..d941cb4e6df 100644 --- 
a/parquet_integration/write_parquet.py +++ b/parquet_integration/write_parquet.py @@ -10,6 +10,7 @@ def case_basic_nullable(size=1): float64 = [0.0, 1.0, None, 3.0, None, 5.0, 6.0, 7.0, None, 9.0] string = ["Hello", None, "aa", "", None, "abc", None, None, "def", "aaa"] boolean = [True, None, False, False, None, True, None, None, True, True] + string_large = ["ABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCD😃🌚🕳👊"] * 10 fields = [ pa.field("int64", pa.int64()), @@ -18,6 +19,7 @@ def case_basic_nullable(size=1): pa.field("bool", pa.bool_()), pa.field("date", pa.timestamp("ms")), pa.field("uint32", pa.uint32()), + pa.field("string_large", pa.utf8()), ] schema = pa.schema(fields) @@ -29,6 +31,7 @@ def case_basic_nullable(size=1): "bool": boolean * size, "date": int64 * size, "uint32": int64 * size, + "string_large": string_large * size, }, schema, f"basic_nullable_{size*10}.parquet", diff --git a/src/array/ord.rs b/src/array/ord.rs index 17dd0bc4653..319af374ab1 100644 --- a/src/array/ord.rs +++ b/src/array/ord.rs @@ -48,51 +48,90 @@ where } fn compare_primitives(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left.as_any().downcast_ref::>().unwrap().clone(); - let right = right.as_any().downcast_ref::>().unwrap().clone(); + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); Box::new(move |i, j| total_cmp(&left.value(i), &right.value(j))) } fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left.as_any().downcast_ref::().unwrap().clone(); - let right = right.as_any().downcast_ref::().unwrap().clone(); + let left = left + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::() + .unwrap() + .clone(); Box::new(move |i, j| left.value(i).cmp(&right.value(j))) } fn compare_f32(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left 
= left.as_any().downcast_ref::>().unwrap().clone(); + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); let right = right .as_any() .downcast_ref::>() - .unwrap().clone(); + .unwrap() + .clone(); Box::new(move |i, j| total_cmp_f32(&left.value(i), &right.value(j))) } fn compare_f64(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left.as_any().downcast_ref::>().unwrap().clone(); + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); let right = right .as_any() .downcast_ref::>() - .unwrap().clone(); + .unwrap() + .clone(); Box::new(move |i, j| total_cmp_f64(&left.value(i), &right.value(j))) } fn compare_string(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left.as_any().downcast_ref::>().unwrap().clone(); - let right = right.as_any().downcast_ref::>().unwrap().clone(); + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); Box::new(move |i, j| left.value(i).cmp(right.value(j))) } fn compare_binary(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left.as_any().downcast_ref::>().unwrap().clone(); - let right = right.as_any().downcast_ref::>().unwrap().clone(); + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); Box::new(move |i, j| left.value(i).cmp(right.value(j))) } -fn compare_dict( - left: &DictionaryArray, - right: &DictionaryArray, -) -> Result +fn compare_dict(left: &DictionaryArray, right: &DictionaryArray) -> Result where K: DictionaryKey, { diff --git a/src/array/specification.rs b/src/array/specification.rs index dd1e8b77ad6..94b407c0328 100644 --- a/src/array/specification.rs +++ b/src/array/specification.rs @@ -79,7 +79,8 @@ pub fn check_offsets_and_utf8(offsets: &[O], values: &[u8]) -> usize let end = window[1].to_usize(); assert!(end <= values.len()); let 
slice = unsafe { std::slice::from_raw_parts(values.as_ptr().add(start), end - start) }; - std::str::from_utf8(slice).expect("A non-utf8 string was passed."); + #[cfg(not(feature = "simdutf8"))] + simdutf8::basic::from_utf8(slice).expect("A non-utf8 string was passed."); }); len } diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index 06666f1f1d0..62b900c018a 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -538,7 +538,7 @@ fn cast_with_options( // perf todo: the offsets are equal; we can speed-up this let iter = array .iter() - .map(|x| x.and_then(|x| std::str::from_utf8(x).ok())); + .map(|x| x.and_then(|x| simdutf8::basic::from_utf8(x).ok())); let array = Utf8Array::::from_trusted_len_iter(iter); Ok(Box::new(array)) @@ -574,7 +574,7 @@ fn cast_with_options( // perf todo: the offsets are equal; we can speed-up this let iter = array .iter() - .map(|x| x.and_then(|x| std::str::from_utf8(x).ok())); + .map(|x| x.and_then(|x| simdutf8::basic::from_utf8(x).ok())); let array = Utf8Array::::from_trusted_len_iter(iter); Ok(Box::new(array)) diff --git a/src/compute/like.rs b/src/compute/like.rs index 2e776267e2b..4480c122827 100644 --- a/src/compute/like.rs +++ b/src/compute/like.rs @@ -173,7 +173,7 @@ fn a_like_binary bool>( let pattern = if let Some(pattern) = map.get(pattern) { pattern } else { - let re_pattern = std::str::from_utf8(pattern) + let re_pattern = simdutf8::basic::from_utf8(pattern) .unwrap() .replace("%", ".*") .replace("_", "."); @@ -231,7 +231,7 @@ fn a_like_binary_scalar bool>( op: F, ) -> Result { let validity = lhs.validity(); - let pattern = std::str::from_utf8(rhs).map_err(|e| { + let pattern = simdutf8::basic::from_utf8(rhs).map_err(|e| { ArrowError::InvalidArgumentError(format!( "Unable to convert the LIKE pattern to string: {}", e diff --git a/src/error.rs b/src/error.rs index 06a3d963625..c1938453494 100644 --- a/src/error.rs +++ b/src/error.rs @@ -51,6 +51,12 @@ impl From for ArrowError { } } +impl From for 
ArrowError { + fn from(error: simdutf8::basic::Utf8Error) -> Self { + ArrowError::External("".to_string(), Box::new(error)) + } +} + impl Display for ArrowError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs index 2ae006e1722..a5915ec77aa 100644 --- a/src/ffi/schema.rs +++ b/src/ffi/schema.rs @@ -435,7 +435,7 @@ unsafe fn read_ne_i32(ptr: *const u8) -> i32 { unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str { let slice = std::slice::from_raw_parts(ptr, len); - std::str::from_utf8(slice).unwrap() + simdutf8::basic::from_utf8(slice).unwrap() } unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) { diff --git a/src/io/avro/read/deserialize.rs b/src/io/avro/read/deserialize.rs index 8004adae75e..514de185a59 100644 --- a/src/io/avro/read/deserialize.rs +++ b/src/io/avro/read/deserialize.rs @@ -104,7 +104,7 @@ pub fn deserialize(mut block: &[u8], rows: usize, schema: Arc) -> Result "Avro format contains a non-usize number of bytes".to_string(), ) })?; - let data = std::str::from_utf8(&block[..len])?; + let data = simdutf8::basic::from_utf8(&block[..len])?; block = &block[len..]; let array = array diff --git a/src/io/avro/read/util.rs b/src/io/avro/read/util.rs index 1bdcbd9fd0f..19934d4f3b2 100644 --- a/src/io/avro/read/util.rs +++ b/src/io/avro/read/util.rs @@ -74,7 +74,7 @@ pub fn read_schema(reader: &mut R) -> AvroResult<(Schema, Codec, [u8; 1 .get("avro.codec") .and_then(|codec| { if let Value::Bytes(ref bytes) = *codec { - std::str::from_utf8(bytes.as_ref()).ok() + simdutf8::basic::from_utf8(bytes.as_ref()).ok() } else { None } diff --git a/src/io/csv/read/deserialize.rs b/src/io/csv/read/deserialize.rs index 57e43f29a65..d543c77f5d9 100644 --- a/src/io/csv/read/deserialize.rs +++ b/src/io/csv/read/deserialize.rs @@ -52,7 +52,7 @@ where fn deserialize_utf8(rows: &[ByteRecord], column: usize) -> Arc { let iter = rows.iter().map(|row| 
match row.get(column) { - Some(bytes) => std::str::from_utf8(bytes).ok(), + Some(bytes) => simdutf8::basic::from_utf8(bytes).ok(), None => None, }); Arc::new(Utf8Array::::from_trusted_len_iter(iter)) @@ -111,20 +111,20 @@ pub fn deserialize_column( lexical_core::parse::(bytes).ok() }), Date32 => deserialize_primitive(rows, column, datatype, |bytes| { - std::str::from_utf8(bytes) + simdutf8::basic::from_utf8(bytes) .ok() .and_then(|x| x.parse::().ok()) .map(|x| x.num_days_from_ce() - EPOCH_DAYS_FROM_CE) }), Date64 => deserialize_primitive(rows, column, datatype, |bytes| { - std::str::from_utf8(bytes) + simdutf8::basic::from_utf8(bytes) .ok() .and_then(|x| x.parse::().ok()) .map(|x| x.timestamp_millis()) }), Timestamp(TimeUnit::Nanosecond, None) => { deserialize_primitive(rows, column, datatype, |bytes| { - std::str::from_utf8(bytes) + simdutf8::basic::from_utf8(bytes) .ok() .and_then(|x| x.parse::().ok()) .map(|x| x.timestamp_nanos()) @@ -132,7 +132,7 @@ pub fn deserialize_column( } Timestamp(TimeUnit::Microsecond, None) => { deserialize_primitive(rows, column, datatype, |bytes| { - std::str::from_utf8(bytes) + simdutf8::basic::from_utf8(bytes) .ok() .and_then(|x| x.parse::().ok()) .map(|x| x.timestamp_nanos() / 1000) @@ -140,7 +140,7 @@ pub fn deserialize_column( } Timestamp(TimeUnit::Millisecond, None) => { deserialize_primitive(rows, column, datatype, |bytes| { - std::str::from_utf8(bytes) + simdutf8::basic::from_utf8(bytes) .ok() .and_then(|x| x.parse::().ok()) .map(|x| x.timestamp_nanos() / 1_000_000) @@ -148,7 +148,7 @@ pub fn deserialize_column( } Timestamp(TimeUnit::Second, None) => { deserialize_primitive(rows, column, datatype, |bytes| { - std::str::from_utf8(bytes) + simdutf8::basic::from_utf8(bytes) .ok() .and_then(|x| x.parse::().ok()) .map(|x| x.timestamp_nanos() / 1_000_000_000) diff --git a/src/io/parquet/read/schema/metadata.rs b/src/io/parquet/read/schema/metadata.rs index 5e7d310ba1c..46c58520768 100644 --- a/src/io/parquet/read/schema/metadata.rs 
+++ b/src/io/parquet/read/schema/metadata.rs @@ -113,7 +113,7 @@ mod tests { .collect::>(); assert_eq!( names, - vec!["int64", "float64", "string", "bool", "date", "uint32"] + vec!["int64", "float64", "string", "bool", "date", "uint32", "string_large"] ); Ok(()) } diff --git a/src/io/parquet/read/statistics/binary.rs b/src/io/parquet/read/statistics/binary.rs index 4f883fa6e11..db0e40276ab 100644 --- a/src/io/parquet/read/statistics/binary.rs +++ b/src/io/parquet/read/statistics/binary.rs @@ -57,12 +57,12 @@ impl TryFrom<&ParquetByteArrayStatistics> for Utf8Statistics { min_value: stats .min_value .as_ref() - .map(|x| std::str::from_utf8(x).map(|x| x.to_string())) + .map(|x| simdutf8::basic::from_utf8(x).map(|x| x.to_string())) .transpose()?, max_value: stats .max_value .as_ref() - .map(|x| std::str::from_utf8(x).map(|x| x.to_string())) + .map(|x| simdutf8::basic::from_utf8(x).map(|x| x.to_string())) .transpose()?, }) }