Skip to content

Commit

Permalink
Rollup merge of rust-lang#101193 - thomcc:win-stdio-nozero, r=ChrisDe…
Browse files Browse the repository at this point in the history
…nton

Avoid zeroing large stack buffers in stdio on Windows

Does what it says on the tin, using `[MaybeUninit<u16>; N]` instead of `[0u16; N]`. These buffers seem to be around 8kb, which is big enough that this is likely to be a very nice perf boost to stdio-heavy windows code.

r? `@ChrisDenton`

*(Note: this PR also has a commit that adds windows to CI, but as it mentions I'll revert that after it comes out green -- I can only do a check build on the machine I'm typing this on)*
  • Loading branch information
JohnTitor authored Aug 30, 2022
2 parents 9337132 + 1b8b2dc commit 1ca7aaa
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions library/std/src/sys/windows/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::char::decode_utf16;
use crate::cmp;
use crate::io;
use crate::mem::MaybeUninit;
use crate::os::windows::io::{FromRawHandle, IntoRawHandle};
use crate::ptr;
use crate::str;
Expand Down Expand Up @@ -169,13 +170,14 @@ fn write(
}

fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> {
let mut utf16 = [0u16; MAX_BUFFER_SIZE / 2];
let mut utf16 = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2];
let mut len_utf16 = 0;
for (chr, dest) in utf8.encode_utf16().zip(utf16.iter_mut()) {
*dest = chr;
*dest = MaybeUninit::new(chr);
len_utf16 += 1;
}
let utf16 = &utf16[..len_utf16];
// Safety: We've initialized `len_utf16` values.
let utf16: &[u16] = unsafe { MaybeUninit::slice_assume_init_ref(&utf16[..len_utf16]) };

let mut written = write_u16s(handle, &utf16)?;

Expand Down Expand Up @@ -250,27 +252,33 @@ impl io::Read for Stdin {
return Ok(bytes_copied);
} else if buf.len() - bytes_copied < 4 {
// Not enough space to get a UTF-8 byte. We will use the incomplete UTF8.
let mut utf16_buf = [0u16; 1];
let mut utf16_buf = [MaybeUninit::new(0); 1];
// Read one u16 character.
let read = read_u16s_fixup_surrogates(handle, &mut utf16_buf, 1, &mut self.surrogate)?;
// Read bytes, using the (now-empty) self.incomplete_utf8 as extra space.
let read_bytes = utf16_to_utf8(&utf16_buf[..read], &mut self.incomplete_utf8.bytes)?;
let read_bytes = utf16_to_utf8(
unsafe { MaybeUninit::slice_assume_init_ref(&utf16_buf[..read]) },
&mut self.incomplete_utf8.bytes,
)?;

// Read in the bytes from incomplete_utf8 until the buffer is full.
self.incomplete_utf8.len = read_bytes as u8;
// No-op if no bytes.
bytes_copied += self.incomplete_utf8.read(&mut buf[bytes_copied..]);
Ok(bytes_copied)
} else {
let mut utf16_buf = [0u16; MAX_BUFFER_SIZE / 2];
let mut utf16_buf = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2];

// In the worst case, a UTF-8 string can take 3 bytes for every `u16` of a UTF-16. So
// we can read at most a third of `buf.len()` chars and uphold the guarantee no data gets
// lost.
let amount = cmp::min(buf.len() / 3, utf16_buf.len());
let read =
read_u16s_fixup_surrogates(handle, &mut utf16_buf, amount, &mut self.surrogate)?;

match utf16_to_utf8(&utf16_buf[..read], buf) {
// Safety `read_u16s_fixup_surrogates` returns the number of items
// initialized.
let utf16s = unsafe { MaybeUninit::slice_assume_init_ref(&utf16_buf[..read]) };
match utf16_to_utf8(utf16s, buf) {
Ok(value) => return Ok(bytes_copied + value),
Err(e) => return Err(e),
}
Expand All @@ -283,14 +291,14 @@ impl io::Read for Stdin {
// This is a best effort, and might not work if we are not the only reader on Stdin.
fn read_u16s_fixup_surrogates(
handle: c::HANDLE,
buf: &mut [u16],
buf: &mut [MaybeUninit<u16>],
mut amount: usize,
surrogate: &mut u16,
) -> io::Result<usize> {
// Insert possibly remaining unpaired surrogate from last read.
let mut start = 0;
if *surrogate != 0 {
buf[0] = *surrogate;
buf[0] = MaybeUninit::new(*surrogate);
*surrogate = 0;
start = 1;
if amount == 1 {
Expand All @@ -303,7 +311,10 @@ fn read_u16s_fixup_surrogates(
let mut amount = read_u16s(handle, &mut buf[start..amount])? + start;

if amount > 0 {
let last_char = buf[amount - 1];
// Safety: The returned `amount` is the number of values initialized,
// and it is not 0, so we know that `buf[amount - 1]` have been
// initialized.
let last_char = unsafe { buf[amount - 1].assume_init() };
if last_char >= 0xD800 && last_char <= 0xDBFF {
// high surrogate
*surrogate = last_char;
Expand All @@ -313,7 +324,8 @@ fn read_u16s_fixup_surrogates(
Ok(amount)
}

fn read_u16s(handle: c::HANDLE, buf: &mut [u16]) -> io::Result<usize> {
// Returns `Ok(n)` if it initialized `n` values in `buf`.
fn read_u16s(handle: c::HANDLE, buf: &mut [MaybeUninit<u16>]) -> io::Result<usize> {
// Configure the `pInputControl` parameter to not only return on `\r\n` but also Ctrl-Z, the
// traditional DOS method to indicate end of character stream / user input (SUB).
// See #38274 and https://stackoverflow.com/questions/43836040/win-api-readconsole.
Expand Down Expand Up @@ -346,8 +358,9 @@ fn read_u16s(handle: c::HANDLE, buf: &mut [u16]) -> io::Result<usize> {
}
break;
}

if amount > 0 && buf[amount as usize - 1] == CTRL_Z {
// Safety: if `amount > 0`, then that many bytes were written, so
// `buf[amount as usize - 1]` has been initialized.
if amount > 0 && unsafe { buf[amount as usize - 1].assume_init() } == CTRL_Z {
amount -= 1;
}
Ok(amount as usize)
Expand Down

0 comments on commit 1ca7aaa

Please sign in to comment.