Skip to content

Commit

Permalink
Improve autovectorization of to_lowercase / to_uppercase functions
Browse files Browse the repository at this point in the history
Refactor the code in the `convert_while_ascii` helper function to make
it more suitable for auto-vectorization and also process the full ascii
prefix of the string. The generic case conversion logic will only be
invoked starting from the first non-ascii character.

The runtime on microbenchmarks with ascii-only inputs improves between
1.5x for short and 4x for long inputs on x86_64 and aarch64.

The new implementation also encapsulates all unsafe inside the
`convert_while_ascii` function.

Fixes rust-lang#123712
  • Loading branch information
jhorstmann committed Jun 2, 2024
1 parent eda9d7f commit b03d939
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 50 deletions.
2 changes: 2 additions & 0 deletions library/alloc/benches/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,5 @@ make_test!(rsplitn_space_char, s, s.rsplitn(10, ' ').count());

make_test!(split_space_str, s, s.split(" ").count());
make_test!(split_ad_str, s, s.split("ad").count());

make_test!(to_lowercase, s, s.to_lowercase());
118 changes: 68 additions & 50 deletions library/alloc/src/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
use core::borrow::{Borrow, BorrowMut};
use core::iter::FusedIterator;
use core::mem;
use core::mem::MaybeUninit;
use core::ptr;
use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher};
use core::unicode::conversions;
Expand Down Expand Up @@ -366,14 +367,9 @@ impl str {
without modifying the original"]
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
pub fn to_lowercase(&self) -> String {
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);

// Safety: we know this is a valid char boundary since
// out.len() is only progressed if ascii bytes are found
let rest = unsafe { self.get_unchecked(out.len()..) };

// Safety: We have written only valid ASCII to our vec
let mut s = unsafe { String::from_utf8_unchecked(out) };
let prefix_len = s.len();

for (i, c) in rest.char_indices() {
if c == 'Σ' {
Expand All @@ -382,8 +378,7 @@ impl str {
// in `SpecialCasing.txt`,
// so hard-code it rather than have a generic "condition" mechanism.
// See https://github.com/rust-lang/rust/issues/26035
let out_len = self.len() - rest.len();
let sigma_lowercase = map_uppercase_sigma(&self, i + out_len);
let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i);
s.push(sigma_lowercase);
} else {
match conversions::to_lower(c) {
Expand Down Expand Up @@ -459,14 +454,7 @@ impl str {
without modifying the original"]
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
pub fn to_uppercase(&self) -> String {
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);

// Safety: we know this is a valid char boundary since
// out.len() is only progressed if ascii bytes are found
let rest = unsafe { self.get_unchecked(out.len()..) };

// Safety: We have written only valid ASCII to our vec
let mut s = unsafe { String::from_utf8_unchecked(out) };
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);

for c in rest.chars() {
match conversions::to_upper(c) {
Expand Down Expand Up @@ -615,50 +603,80 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
}

/// Converts the bytes while the bytes are still ascii.
/// Converts leading ascii bytes in `s` by calling the `convert` function.
///
/// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
/// Returns a vec with the converted bytes.
///
/// Returns a tuple of the converted prefix and the remainder starting from
/// the first non-ascii character.
#[inline]
#[cfg(not(test))]
#[cfg(not(no_global_oom_handling))]
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
let mut out = Vec::with_capacity(b.len());

fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
// process the input in chunks to enable auto-vectorization
const USIZE_SIZE: usize = mem::size_of::<usize>();
const MAGIC_UNROLL: usize = 2;
const N: usize = USIZE_SIZE * MAGIC_UNROLL;
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);

let mut i = 0;
unsafe {
while i + N <= b.len() {
// Safety: we have checks the sizes `b` and `out` to know that our
let in_chunk = b.get_unchecked(i..i + N);
let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);

let mut bits = 0;
for j in 0..MAGIC_UNROLL {
// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
// safety: in_chunk is valid bytes in the range
bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
}
// if our chunks aren't ascii, then return only the prior bytes as init
if bits & NONASCII_MASK != 0 {
break;
}
let mut slice = s.as_bytes();
let mut out = Vec::with_capacity(slice.len());
let mut out_slice = &mut out.spare_capacity_mut()[..slice.len()];

// perform the case conversions on N bytes (gets heavily autovec'd)
for j in 0..N {
// safety: in_chunk and out_chunk is valid bytes in the range
let out = out_chunk.get_unchecked_mut(j);
out.write(convert(in_chunk.get_unchecked(j)));
}
let mut ascii_prefix_len = 0_usize;
let mut is_ascii = [false; N];

while slice.len() >= N {
// Safety: checked in loop condition
let chunk = unsafe { slice.get_unchecked(..N) };
// Safety: out_slice has same length as input slice and gets sliced with the same offsets
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };

for j in 0..N {
is_ascii[j] = chunk[j] <= 127;
}

// mark these bytes as initialised
i += N;
// auto-vectorization for this check is a bit fragile,
// sum and comparing against the chunk size gives the best result,
// specifically a pmovmsk instruction on x86.
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
break;
}
out.set_len(i);

for j in 0..N {
out_chunk[j] = MaybeUninit::new(convert(&chunk[j]));
}

ascii_prefix_len += N;
slice = unsafe { slice.get_unchecked(N..) };
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
}

// handle the remainder as individual bytes
while slice.len() > 0 {
let byte = slice[0];
if byte > 127 {
break;
}
// Safety: out_slice has same length as input slice and gets sliced with the same offsets
unsafe {
*out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte));
}
ascii_prefix_len += 1;
slice = unsafe { slice.get_unchecked(1..) };
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
}

out
unsafe {
// SAFETY: ascii_prefix_len bytes have been initialized above
out.set_len(ascii_prefix_len);

// SAFETY: We have written only valid ascii to the output vec
let ascii_string = String::from_utf8_unchecked(out);

// SAFETY: we know this is a valid char boundary
// since we only skipped over leading ascii bytes
let rest = core::str::from_utf8_unchecked(slice);

(ascii_string, rest)
}
}
17 changes: 17 additions & 0 deletions library/alloc/tests/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,19 @@ fn to_lowercase() {
assert_eq!("Α'Σ".to_lowercase(), "α'ς");
assert_eq!("Α''Σ".to_lowercase(), "α''ς");

assert_eq!("aΣ".to_lowercase(), "aς");
assert_eq!("a'Σ".to_lowercase(), "a'ς");
assert_eq!("a''Σ".to_lowercase(), "a''ς");

assert_eq!("ÄΣ".to_lowercase(), "äς");
assert_eq!("ä'Σ".to_lowercase(), "ä'ς");
assert_eq!("ä''Σ".to_lowercase(), "ä''ς");

// input lengths around the boundary of the chunk size used by the ascii prefix optimization
assert_eq!("abcdefghijklmnoΣ".to_lowercase(), "abcdefghijklmnoς");
assert_eq!("abcdefghijklmnopΣ".to_lowercase(), "abcdefghijklmnopς");
assert_eq!("abcdefghijklmnopqΣ".to_lowercase(), "abcdefghijklmnopqς");

assert_eq!("ΑΣ Α".to_lowercase(), "ας α");
assert_eq!("Α'Σ Α".to_lowercase(), "α'ς α");
assert_eq!("Α''Σ Α".to_lowercase(), "α''ς α");
Expand All @@ -1840,6 +1853,10 @@ fn to_lowercase() {
assert_eq!("Α 'Σ".to_lowercase(), "α 'σ");
assert_eq!("Α ''Σ".to_lowercase(), "α ''σ");

assert_eq!("Ä Σ".to_lowercase(), "ä σ");
assert_eq!("Ä 'Σ".to_lowercase(), "ä 'σ");
assert_eq!("Ä ''Σ".to_lowercase(), "ä ''σ");

assert_eq!("Σ".to_lowercase(), "σ");
assert_eq!("'Σ".to_lowercase(), "'σ");
assert_eq!("''Σ".to_lowercase(), "''σ");
Expand Down

0 comments on commit b03d939

Please sign in to comment.