Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Improved performance of utf8 validation of large strings via simdutf8 (-40%) #426

Merged
merged 7 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,13 @@ avro-rs = { version = "0.13", optional = true, default_features = false }

# for division/remainder optimization at runtime
strength_reduce = { version = "0.2", optional = true }

# For instruction multiversioning
multiversion = { version = "0.6.1", optional = true }

# For SIMD utf8 validation
simdutf8 = "0.1.3"

[dev-dependencies]
rand = "0.8"
criterion = "0.3"
Expand Down
5 changes: 5 additions & 0 deletions benches/read_parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ fn add_benchmark(c: &mut Criterion) {
b.iter(|| read_decompressed_pages(&buffer, size * 8, 2).unwrap())
});

let a = format!("read utf8 large 2^{}", i);
c.bench_function(&a, |b| {
b.iter(|| read_decompressed_pages(&buffer, size * 8, 6).unwrap())
});

let a = format!("read bool 2^{}", i);
c.bench_function(&a, |b| {
b.iter(|| read_decompressed_pages(&buffer, size * 8, 3).unwrap())
Expand Down
3 changes: 3 additions & 0 deletions parquet_integration/write_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def case_basic_nullable(size=1):
float64 = [0.0, 1.0, None, 3.0, None, 5.0, 6.0, 7.0, None, 9.0]
string = ["Hello", None, "aa", "", None, "abc", None, None, "def", "aaa"]
boolean = [True, None, False, False, None, True, None, None, True, True]
string_large = ["ABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCD😃🌚🕳👊"] * 10

fields = [
pa.field("int64", pa.int64()),
Expand All @@ -18,6 +19,7 @@ def case_basic_nullable(size=1):
pa.field("bool", pa.bool_()),
pa.field("date", pa.timestamp("ms")),
pa.field("uint32", pa.uint32()),
pa.field("string_large", pa.utf8()),
]
schema = pa.schema(fields)

Expand All @@ -29,6 +31,7 @@ def case_basic_nullable(size=1):
"bool": boolean * size,
"date": int64 * size,
"uint32": int64 * size,
"string_large": string_large * size,
},
schema,
f"basic_nullable_{size*10}.parquet",
Expand Down
71 changes: 55 additions & 16 deletions src/array/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,51 +48,90 @@ where
}

fn compare_primitives<T: NativeType + Ord>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap().clone();
let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
Box::new(move |i, j| total_cmp(&left.value(i), &right.value(j)))
}

fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<BooleanArray>().unwrap().clone();
let right = right.as_any().downcast_ref::<BooleanArray>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.clone();
Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
}

fn compare_f32(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<f32>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<PrimitiveArray<f32>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<PrimitiveArray<f32>>()
.unwrap().clone();
.unwrap()
.clone();
Box::new(move |i, j| total_cmp_f32(&left.value(i), &right.value(j)))
}

fn compare_f64(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<PrimitiveArray<f64>>()
.unwrap().clone();
.unwrap()
.clone();
Box::new(move |i, j| total_cmp_f64(&left.value(i), &right.value(j)))
}

fn compare_string<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<Utf8Array<O>>().unwrap().clone();
let right = right.as_any().downcast_ref::<Utf8Array<O>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<Utf8Array<O>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<Utf8Array<O>>()
.unwrap()
.clone();
Box::new(move |i, j| left.value(i).cmp(right.value(j)))
}

fn compare_binary<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator {
let left = left.as_any().downcast_ref::<BinaryArray<O>>().unwrap().clone();
let right = right.as_any().downcast_ref::<BinaryArray<O>>().unwrap().clone();
let left = left
.as_any()
.downcast_ref::<BinaryArray<O>>()
.unwrap()
.clone();
let right = right
.as_any()
.downcast_ref::<BinaryArray<O>>()
.unwrap()
.clone();
Box::new(move |i, j| left.value(i).cmp(right.value(j)))
}

fn compare_dict<K>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
) -> Result<DynComparator>
fn compare_dict<K>(left: &DictionaryArray<K>, right: &DictionaryArray<K>) -> Result<DynComparator>
where
K: DictionaryKey,
{
Expand Down
3 changes: 2 additions & 1 deletion src/array/specification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ pub fn check_offsets_and_utf8<O: Offset>(offsets: &[O], values: &[u8]) -> usize
let end = window[1].to_usize();
assert!(end <= values.len());
let slice = unsafe { std::slice::from_raw_parts(values.as_ptr().add(start), end - start) };
std::str::from_utf8(slice).expect("A non-utf8 string was passed.");
#[cfg(not(feature = "simdutf8"))]
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
simdutf8::basic::from_utf8(slice).expect("A non-utf8 string was passed.");
});
len
}
4 changes: 2 additions & 2 deletions src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ fn cast_with_options(
// perf todo: the offsets are equal; we can speed-up this
let iter = array
.iter()
.map(|x| x.and_then(|x| std::str::from_utf8(x).ok()));
.map(|x| x.and_then(|x| simdutf8::basic::from_utf8(x).ok()));

let array = Utf8Array::<i32>::from_trusted_len_iter(iter);
Ok(Box::new(array))
Expand Down Expand Up @@ -574,7 +574,7 @@ fn cast_with_options(
// perf todo: the offsets are equal; we can speed-up this
let iter = array
.iter()
.map(|x| x.and_then(|x| std::str::from_utf8(x).ok()));
.map(|x| x.and_then(|x| simdutf8::basic::from_utf8(x).ok()));

let array = Utf8Array::<i64>::from_trusted_len_iter(iter);
Ok(Box::new(array))
Expand Down
4 changes: 2 additions & 2 deletions src/compute/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ fn a_like_binary<O: Offset, F: Fn(bool) -> bool>(
let pattern = if let Some(pattern) = map.get(pattern) {
pattern
} else {
let re_pattern = std::str::from_utf8(pattern)
let re_pattern = simdutf8::basic::from_utf8(pattern)
.unwrap()
.replace("%", ".*")
.replace("_", ".");
Expand Down Expand Up @@ -231,7 +231,7 @@ fn a_like_binary_scalar<O: Offset, F: Fn(bool) -> bool>(
op: F,
) -> Result<BooleanArray> {
let validity = lhs.validity();
let pattern = std::str::from_utf8(rhs).map_err(|e| {
let pattern = simdutf8::basic::from_utf8(rhs).map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"Unable to convert the LIKE pattern to string: {}",
e
Expand Down
6 changes: 6 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ impl From<std::str::Utf8Error> for ArrowError {
}
}

impl From<simdutf8::basic::Utf8Error> for ArrowError {
fn from(error: simdutf8::basic::Utf8Error) -> Self {
ArrowError::External("".to_string(), Box::new(error))
}
}

impl Display for ArrowError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down
2 changes: 1 addition & 1 deletion src/ffi/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ unsafe fn read_ne_i32(ptr: *const u8) -> i32 {

unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str {
let slice = std::slice::from_raw_parts(ptr, len);
std::str::from_utf8(slice).unwrap()
simdutf8::basic::from_utf8(slice).unwrap()
}

unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) {
Expand Down
2 changes: 1 addition & 1 deletion src/io/avro/read/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub fn deserialize(mut block: &[u8], rows: usize, schema: Arc<Schema>) -> Result
"Avro format contains a non-usize number of bytes".to_string(),
)
})?;
let data = std::str::from_utf8(&block[..len])?;
let data = simdutf8::basic::from_utf8(&block[..len])?;
block = &block[len..];

let array = array
Expand Down
2 changes: 1 addition & 1 deletion src/io/avro/read/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn read_schema<R: Read>(reader: &mut R) -> AvroResult<(Schema, Codec, [u8; 1
.get("avro.codec")
.and_then(|codec| {
if let Value::Bytes(ref bytes) = *codec {
std::str::from_utf8(bytes.as_ref()).ok()
simdutf8::basic::from_utf8(bytes.as_ref()).ok()
} else {
None
}
Expand Down
14 changes: 7 additions & 7 deletions src/io/csv/read/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ where

fn deserialize_utf8<O: Offset>(rows: &[ByteRecord], column: usize) -> Arc<dyn Array> {
let iter = rows.iter().map(|row| match row.get(column) {
Some(bytes) => std::str::from_utf8(bytes).ok(),
Some(bytes) => simdutf8::basic::from_utf8(bytes).ok(),
None => None,
});
Arc::new(Utf8Array::<O>::from_trusted_len_iter(iter))
Expand Down Expand Up @@ -111,44 +111,44 @@ pub fn deserialize_column(
lexical_core::parse::<f64>(bytes).ok()
}),
Date32 => deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDate>().ok())
.map(|x| x.num_days_from_ce() - EPOCH_DAYS_FROM_CE)
}),
Date64 => deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_millis())
}),
Timestamp(TimeUnit::Nanosecond, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos())
})
}
Timestamp(TimeUnit::Microsecond, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos() / 1000)
})
}
Timestamp(TimeUnit::Millisecond, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos() / 1_000_000)
})
}
Timestamp(TimeUnit::Second, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
std::str::from_utf8(bytes)
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos() / 1_000_000_000)
Expand Down
2 changes: 1 addition & 1 deletion src/io/parquet/read/schema/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ mod tests {
.collect::<Vec<_>>();
assert_eq!(
names,
vec!["int64", "float64", "string", "bool", "date", "uint32"]
vec!["int64", "float64", "string", "bool", "date", "uint32", "string_large"]
);
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions src/io/parquet/read/statistics/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ impl TryFrom<&ParquetByteArrayStatistics> for Utf8Statistics {
min_value: stats
.min_value
.as_ref()
.map(|x| std::str::from_utf8(x).map(|x| x.to_string()))
.map(|x| simdutf8::basic::from_utf8(x).map(|x| x.to_string()))
.transpose()?,
max_value: stats
.max_value
.as_ref()
.map(|x| std::str::from_utf8(x).map(|x| x.to_string()))
.map(|x| simdutf8::basic::from_utf8(x).map(|x| x.to_string()))
.transpose()?,
})
}
Expand Down