Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Improved performance of utf8 validation of large strings via `simdutf…
Browse files Browse the repository at this point in the history
…8` (-40%) (#426)
  • Loading branch information
Dandandan authored Sep 20, 2021
1 parent cb06aea commit 7dedd02
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 34 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,13 @@ avro-rs = { version = "0.13", optional = true, default_features = false }

# for division/remainder optimization at runtime
strength_reduce = { version = "0.2", optional = true }

# For instruction multiversioning
multiversion = { version = "0.6.1", optional = true }

# For SIMD utf8 validation
simdutf8 = "0.1.3"

[dev-dependencies]
rand = "0.8"
criterion = "0.3"
Expand Down
5 changes: 5 additions & 0 deletions benches/read_parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ fn add_benchmark(c: &mut Criterion) {
b.iter(|| read_decompressed_pages(&buffer, size * 8, 2).unwrap())
});

let a = format!("read utf8 large 2^{}", i);
c.bench_function(&a, |b| {
b.iter(|| read_decompressed_pages(&buffer, size * 8, 6).unwrap())
});

let a = format!("read bool 2^{}", i);
c.bench_function(&a, |b| {
b.iter(|| read_decompressed_pages(&buffer, size * 8, 3).unwrap())
Expand Down
3 changes: 3 additions & 0 deletions parquet_integration/write_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def case_basic_nullable(size=1):
float64 = [0.0, 1.0, None, 3.0, None, 5.0, 6.0, 7.0, None, 9.0]
string = ["Hello", None, "aa", "", None, "abc", None, None, "def", "aaa"]
boolean = [True, None, False, False, None, True, None, None, True, True]
string_large = ["ABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCD😃🌚🕳👊"] * 10

fields = [
pa.field("int64", pa.int64()),
Expand All @@ -18,6 +19,7 @@ def case_basic_nullable(size=1):
pa.field("bool", pa.bool_()),
pa.field("date", pa.timestamp("ms")),
pa.field("uint32", pa.uint32()),
pa.field("string_large", pa.utf8()),
]
schema = pa.schema(fields)

Expand All @@ -29,6 +31,7 @@ def case_basic_nullable(size=1):
"bool": boolean * size,
"date": int64 * size,
"uint32": int64 * size,
"string_large": string_large * size,
},
schema,
f"basic_nullable_{size*10}.parquet",
Expand Down
71 changes: 55 additions & 16 deletions src/array/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,51 +48,90 @@ where
}

fn compare_primitives<T: NativeType + Ord>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap().clone();
let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
Box::new(move |i, j| total_cmp(&left.value(i), &right.value(j)))
}

fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<BooleanArray>().unwrap().clone();
let right = right.as_any().downcast_ref::<BooleanArray>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.clone();
Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
}

fn compare_f32(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<f32>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<PrimitiveArray<f32>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<PrimitiveArray<f32>>()
.unwrap().clone();
.unwrap()
.clone();
Box::new(move |i, j| total_cmp_f32(&left.value(i), &right.value(j)))
}

fn compare_f64(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap().clone();
.unwrap()
.clone();
Box::new(move |i, j| total_cmp_f64(&left.value(i), &right.value(j)))
}

fn compare_string<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<Utf8Array<O>>().unwrap().clone();
let right = right.as_any().downcast_ref::<Utf8Array<O>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<Utf8Array<O>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<Utf8Array<O>>()
.unwrap()
.clone();
Box::new(move |i, j| left.value(i).cmp(right.value(j)))
}

fn compare_binary<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<BinaryArray<O>>().unwrap().clone();
let right = right.as_any().downcast_ref::<BinaryArray<O>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<BinaryArray<O>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<BinaryArray<O>>()
.unwrap()
.clone();
Box::new(move |i, j| left.value(i).cmp(right.value(j)))
}

fn compare_dict<K>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
) -> Result<DynComparator>
fn compare_dict<K>(left: &DictionaryArray<K>, right: &DictionaryArray<K>) -> Result<DynComparator>
where
K: DictionaryKey,
{
Expand Down
3 changes: 2 additions & 1 deletion src/array/specification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ pub fn check_offsets_and_utf8<O: Offset>(offsets: &[O], values: &[u8]) -> usize
let end = window[1].to_usize();
assert!(end <= values.len());
let slice = unsafe { std::slice::from_raw_parts(values.as_ptr().add(start), end - start) };
std::str::from_utf8(slice).expect("A non-utf8 string was passed.");
#[cfg(not(feature = "simdutf8"))]
simdutf8::basic::from_utf8(slice).expect("A non-utf8 string was passed.");
});
len
}
4 changes: 2 additions & 2 deletions src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ fn cast_with_options(
// perf todo: the offsets are equal; we can speed-up this
let iter = array
.iter()
.map(|x| x.and_then(|x| std::str::from_utf8(x).ok()));
.map(|x| x.and_then(|x| simdutf8::basic::from_utf8(x).ok()));

let array = Utf8Array::<i32>::from_trusted_len_iter(iter);
Ok(Box::new(array))
Expand Down Expand Up @@ -574,7 +574,7 @@ fn cast_with_options(
// perf todo: the offsets are equal; we can speed-up this
let iter = array
.iter()
.map(|x| x.and_then(|x| std::str::from_utf8(x).ok()));
.map(|x| x.and_then(|x| simdutf8::basic::from_utf8(x).ok()));

let array = Utf8Array::<i64>::from_trusted_len_iter(iter);
Ok(Box::new(array))
Expand Down
4 changes: 2 additions & 2 deletions src/compute/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ fn a_like_binary<O: Offset, F: Fn(bool) -> bool>(
let pattern = if let Some(pattern) = map.get(pattern) {
pattern
} else {
let re_pattern = std::str::from_utf8(pattern)
let re_pattern = simdutf8::basic::from_utf8(pattern)
.unwrap()
.replace("%", ".*")
.replace("_", ".");
Expand Down Expand Up @@ -231,7 +231,7 @@ fn a_like_binary_scalar<O: Offset, F: Fn(bool) -> bool>(
op: F,
) -> Result<BooleanArray> {
let validity = lhs.validity();
let pattern = std::str::from_utf8(rhs).map_err(|e| {
let pattern = simdutf8::basic::from_utf8(rhs).map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"Unable to convert the LIKE pattern to string: {}",
e
Expand Down
6 changes: 6 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ impl From<std::str::Utf8Error> for ArrowError {
}
}

impl From<simdutf8::basic::Utf8Error> for ArrowError {
fn from(error: simdutf8::basic::Utf8Error) -> Self {
ArrowError::External("".to_string(), Box::new(error))
}
}

impl Display for ArrowError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down
2 changes: 1 addition & 1 deletion src/ffi/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ unsafe fn read_ne_i32(ptr: *const u8) -> i32 {

unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str {
let slice = std::slice::from_raw_parts(ptr, len);
std::str::from_utf8(slice).unwrap()
simdutf8::basic::from_utf8(slice).unwrap()
}

unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) {
Expand Down
2 changes: 1 addition & 1 deletion src/io/avro/read/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub fn deserialize(mut block: &[u8], rows: usize, schema: Arc<Schema>) -> Result
"Avro format contains a non-usize number of bytes".to_string(),
)
})?;
let data = std::str::from_utf8(&block[..len])?;
let data = simdutf8::basic::from_utf8(&block[..len])?;
block = &block[len..];

let array = array
Expand Down
2 changes: 1 addition & 1 deletion src/io/avro/read/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn read_schema<R: Read>(reader: &mut R) -> AvroResult<(Schema, Codec, [u8; 1
.get("avro.codec")
.and_then(|codec| {
if let Value::Bytes(ref bytes) = *codec {
std::str::from_utf8(bytes.as_ref()).ok()
simdutf8::basic::from_utf8(bytes.as_ref()).ok()
} else {
None
}
Expand Down
14 changes: 7 additions & 7 deletions src/io/csv/read/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ where

fn deserialize_utf8<O: Offset>(rows: &[ByteRecord], column: usize) -> Arc<dyn Array> {
let iter = rows.iter().map(|row| match row.get(column) {
Some(bytes) => std::str::from_utf8(bytes).ok(),
Some(bytes) => simdutf8::basic::from_utf8(bytes).ok(),
None => None,
});
Arc::new(Utf8Array::<O>::from_trusted_len_iter(iter))
Expand Down Expand Up @@ -111,44 +111,44 @@ pub fn deserialize_column(
lexical_core::parse::<f64>(bytes).ok()
}),
Date32 => deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDate>().ok())
.map(|x| x.num_days_from_ce() - EPOCH_DAYS_FROM_CE)
}),
Date64 => deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_millis())
}),
Timestamp(TimeUnit::Nanosecond, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos())
})
}
Timestamp(TimeUnit::Microsecond, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos() / 1000)
})
}
Timestamp(TimeUnit::Millisecond, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos() / 1_000_000)
})
}
Timestamp(TimeUnit::Second, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos() / 1_000_000_000)
Expand Down
2 changes: 1 addition & 1 deletion src/io/parquet/read/schema/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ mod tests {
.collect::<Vec<_>>();
assert_eq!(
names,
vec!["int64", "float64", "string", "bool", "date", "uint32"]
vec!["int64", "float64", "string", "bool", "date", "uint32", "string_large"]
);
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions src/io/parquet/read/statistics/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ impl TryFrom<&ParquetByteArrayStatistics> for Utf8Statistics {
min_value: stats
.min_value
.as_ref()
.map(|x| std::str::from_utf8(x).map(|x| x.to_string()))
.map(|x| simdutf8::basic::from_utf8(x).map(|x| x.to_string()))
.transpose()?,
max_value: stats
.max_value
.as_ref()
.map(|x| std::str::from_utf8(x).map(|x| x.to_string()))
.map(|x| simdutf8::basic::from_utf8(x).map(|x| x.to_string()))
.transpose()?,
})
}
Expand Down

0 comments on commit 7dedd02

Please sign in to comment.