Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Stream ReadAtLeast and ReadExactly #69272

Merged
merged 16 commits into from
May 20, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal bool TryGetNextHeader(Stream archiveStream, bool copyData)
Span<byte> buffer = rented.AsSpan(0, TarHelpers.RecordSize); // minimumLength means the array could've been larger
buffer.Clear(); // Rented arrays aren't clean

TarHelpers.ReadOrThrow(archiveStream, buffer);
archiveStream.ReadExactly(buffer);

try
{
Expand Down Expand Up @@ -486,10 +486,7 @@ private void ReadExtendedAttributesBlock(Stream archiveStream)
}

byte[] buffer = new byte[(int)_size];
if (archiveStream.Read(buffer.AsSpan()) != _size)
{
throw new EndOfStreamException();
}
archiveStream.ReadExactly(buffer);

string dataAsString = TarHelpers.GetTrimmedUtf8String(buffer);

Expand Down Expand Up @@ -520,11 +517,7 @@ private void ReadGnuLongPathDataBlock(Stream archiveStream)
}

byte[] buffer = new byte[(int)_size];

if (archiveStream.Read(buffer.AsSpan()) != _size)
{
throw new EndOfStreamException();
}
archiveStream.ReadExactly(buffer);

string longPath = TarHelpers.GetTrimmedUtf8String(buffer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,6 @@ private static string GetTrimmedString(ReadOnlySpan<byte> buffer, Encoding encod
// removing the trailing null or space chars.
internal static string GetTrimmedUtf8String(ReadOnlySpan<byte> buffer) => GetTrimmedString(buffer, Encoding.UTF8);

// Reads the specified number of bytes and stores it in the byte buffer passed by reference.
// Throws if end of stream is reached.
internal static void ReadOrThrow(Stream archiveStream, Span<byte> buffer)
{
int totalRead = 0;
while (totalRead < buffer.Length)
{
int bytesRead = archiveStream.Read(buffer.Slice(totalRead));
if (bytesRead == 0)
{
throw new EndOfStreamException();
}
totalRead += bytesRead;
}
}

// Returns true if it successfully converts the specified string to a DateTimeOffset, false otherwise.
internal static bool TryConvertToDateTimeOffset(string value, out DateTimeOffset timestamp)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,10 @@ internal static Encoding GetEncoding(string text)
/// </summary>
internal static void ReadBytes(Stream stream, byte[] buffer, int bytesToRead)
{
int bytesLeftToRead = bytesToRead;

int totalBytesRead = 0;

while (bytesLeftToRead > 0)
int bytesRead = stream.ReadAtLeast(buffer.AsSpan(0, bytesToRead), bytesToRead, throwOnEndOfStream: false);
if (bytesRead < bytesToRead)
{
int bytesRead = stream.Read(buffer, totalBytesRead, bytesLeftToRead);
if (bytesRead == 0) throw new IOException(SR.UnexpectedEndOfStream);

totalBytesRead += bytesRead;
bytesLeftToRead -= bytesRead;
throw new IOException(SR.UnexpectedEndOfStream);
eerhardt marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,18 +353,13 @@ private static ValueTask WriteAsync(Stream stream, Memory<byte> buffer, bool asy

private static async ValueTask ReadToFillAsync(Stream stream, Memory<byte> buffer, bool async)
{
while (buffer.Length != 0)
{
int bytesRead = async
? await stream.ReadAsync(buffer).ConfigureAwait(false)
: stream.Read(buffer.Span);
int bytesRead = async
? await stream.ReadAtLeastAsync(buffer, buffer.Length, throwOnEndOfStream: false).ConfigureAwait(false)
: stream.ReadAtLeast(buffer.Span, buffer.Length, throwOnEndOfStream: false);

if (bytesRead == 0)
{
throw new IOException(SR.net_http_invalid_response_premature_eof);
}

buffer = buffer[bytesRead..];
if (bytesRead < buffer.Length)
{
throw new IOException(SR.net_http_invalid_response_premature_eof);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,9 @@ void HandleExceptions(Exception ex)
}
}

private async ValueTask ReadToFillAsync(Stream stream, Memory<byte> buffer)
private ValueTask ReadToFillAsync(Stream stream, Memory<byte> buffer)
eerhardt marked this conversation as resolved.
Show resolved Hide resolved
{
while (!buffer.IsEmpty)
{
int bytesRead = await stream.ReadAsync(buffer).ConfigureAwait(false);
if (bytesRead == 0)
throw new Exception("Incomplete request");

buffer = buffer.Slice(bytesRead);
}
return stream.ReadExactlyAsync(buffer);
}

public async ValueTask DisposeAsync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,24 +424,15 @@ private async ValueTask<int> ReadAsync<TIOAdapter>(Memory<byte> buffer, Cancella

static async ValueTask<int> ReadAllAsync(Stream stream, Memory<byte> buffer, bool allowZeroRead, CancellationToken cancellationToken)
{
int read = 0;

do
int read = await TIOAdapter.ReadAtLeastAsync(
stream, buffer, buffer.Length, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (read < buffer.Length)
{
int bytes = await TIOAdapter.ReadAsync(stream, buffer, cancellationToken).ConfigureAwait(false);
if (bytes == 0)
if (read != 0 || !allowZeroRead)
{
if (read != 0 || !allowZeroRead)
{
throw new IOException(SR.net_io_eof);
}
break;
throw new IOException(SR.net_io_eof);
}

buffer = buffer.Slice(bytes);
read += bytes;
}
while (!buffer.IsEmpty);

return read;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace System.Net.Security
internal interface IReadWriteAdapter
{
static abstract ValueTask<int> ReadAsync(Stream stream, Memory<byte> buffer, CancellationToken cancellationToken);
static abstract ValueTask<int> ReadAtLeastAsync(Stream stream, Memory<byte> buffer, int minimumBytes, bool throwOnEndOfStream, CancellationToken cancellationToken);
static abstract ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken);
static abstract Task FlushAsync(Stream stream, CancellationToken cancellationToken);
static abstract Task WaitAsync(TaskCompletionSource<bool> waiter);
Expand All @@ -20,6 +21,9 @@ internal interface IReadWriteAdapter
public static ValueTask<int> ReadAsync(Stream stream, Memory<byte> buffer, CancellationToken cancellationToken) =>
stream.ReadAsync(buffer, cancellationToken);

public static ValueTask<int> ReadAtLeastAsync(Stream stream, Memory<byte> buffer, int minimumBytes, bool throwOnEndOfStream, CancellationToken cancellationToken) =>
stream.ReadAtLeastAsync(buffer, minimumBytes, throwOnEndOfStream, cancellationToken);

public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
stream.WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken);

Expand All @@ -33,6 +37,9 @@ public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int
public static ValueTask<int> ReadAsync(Stream stream, Memory<byte> buffer, CancellationToken cancellationToken) =>
new ValueTask<int>(stream.Read(buffer.Span));

public static ValueTask<int> ReadAtLeastAsync(Stream stream, Memory<byte> buffer, int minimumBytes, bool throwOnEndOfStream, CancellationToken cancellationToken) =>
new ValueTask<int>(stream.ReadAtLeast(buffer.Span, minimumBytes, throwOnEndOfStream));

public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
stream.Write(buffer, offset, count);
Expand Down
34 changes: 11 additions & 23 deletions src/libraries/System.Net.Security/src/System/Net/StreamFramer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,17 @@ internal sealed class StreamFramer

byte[] buffer = _readHeaderBuffer;

int bytesRead;
int offset = 0;
while (offset < buffer.Length)
int bytesRead = await TAdapter.ReadAtLeastAsync(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a risk here. Previously if the async read pends, it would force the ReadMessageAsync state machine to be allocated. Now if the async read pends, it will force the ReadMessageAsync state machine to be allocated and the ReadAtLeastAsync state machine to be allocated.

We should consider using the pooled async state machine builders on the ReadAtLeastAsync async methods.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this as simple as adding

[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))]

to the Stream.ReadAtLeastAsyncCore method? Or is it more involved than that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation-wise, that's it, easy peasy. The important thing is validating performance. The concern with such pooling is that it can have a negative impact on the larger system, especially if it ends up creating more gen2 to gen0 references (by referencing gen0 objects from pooled objects that are likely to eventually be gen2). But I worry that without doing this, the methods will end up not being used because folks will see the allocation and choose to avoid the helpers as a result.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's chat tomorrow about how I can validate the performance here. I'd like to get your thoughts on what to try / look for.

stream, buffer, buffer.Length, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (bytesRead < buffer.Length)
{
bytesRead = await TAdapter.ReadAsync(stream, buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
if (offset == 0)
{
// m_Eof, return null
_eof = true;
return null;
}

throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
// m_Eof, return null
_eof = true;
return null;
}

offset += bytesRead;
throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
}

_curReadHeader.CopyFrom(buffer, 0);
Expand All @@ -61,16 +54,11 @@ internal sealed class StreamFramer

buffer = new byte[_curReadHeader.PayloadSize];

offset = 0;
while (offset < buffer.Length)
bytesRead = await TAdapter.ReadAtLeastAsync(
stream, buffer, buffer.Length, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (bytesRead < buffer.Length)
{
bytesRead = await TAdapter.ReadAsync(stream, buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
}

offset += bytesRead;
throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
}
return buffer;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -779,16 +779,18 @@ private async ValueTask<TResult> ReceiveAsyncPrivate<TResult>(Memory<byte> paylo
totalBytesReceived += receiveBufferBytesToCopy;
}

while (totalBytesReceived < limit)
if (totalBytesReceived < limit)
{
int numBytesRead = await _stream.ReadAsync(header.Compressed ?
_inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) :
payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived),
cancellationToken).ConfigureAwait(false);
if (numBytesRead <= 0)
int bytesToRead = limit - totalBytesReceived;
Memory<byte> readBuffer = header.Compressed ?
_inflater!.Memory.Slice(totalBytesReceived, bytesToRead) :
payloadBuffer.Slice(totalBytesReceived, bytesToRead);

int numBytesRead = await _stream.ReadAtLeastAsync(readBuffer, bytesToRead,
throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (numBytesRead < bytesToRead)
{
ThrowEOFUnexpected();
break;
}
totalBytesReceived += numBytesRead;
}
Expand Down Expand Up @@ -1359,17 +1361,13 @@ private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, Canc
_receiveBufferOffset = 0;

// While we don't have enough data, read more.
while (_receiveBufferCount < minimumRequiredBytes)
int numRead = await _stream.ReadAtLeastAsync(_receiveBuffer.Slice(_receiveBufferCount), minimumRequiredBytes,
throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (numRead < minimumRequiredBytes)
{
int numRead = await _stream.ReadAsync(_receiveBuffer.Slice(_receiveBufferCount), cancellationToken).ConfigureAwait(false);
Debug.Assert(numRead >= 0, $"Expected non-negative bytes read, got {numRead}");
if (numRead <= 0)
{
ThrowEOFUnexpected();
break;
}
_receiveBufferCount += numRead;
ThrowEOFUnexpected();
}
_receiveBufferCount += numRead;
}
}

Expand Down
39 changes: 4 additions & 35 deletions src/libraries/System.Private.CoreLib/src/System/IO/BinaryReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -483,18 +483,7 @@ public virtual byte[] ReadBytes(int count)
}

byte[] result = new byte[count];
int numRead = 0;
do
{
int n = _stream.Read(result, numRead, count);
if (n == 0)
{
break;
}

numRead += n;
count -= n;
} while (count > 0);
int numRead = _stream.ReadAtLeast(result, result.Length, throwOnEndOfStream: false);

if (numRead != result.Length)
{
Expand All @@ -521,16 +510,7 @@ private ReadOnlySpan<byte> InternalRead(int numBytes)
{
ThrowIfDisposed();

int bytesRead = 0;
do
{
int n = _stream.Read(_buffer, bytesRead, numBytes - bytesRead);
if (n == 0)
{
ThrowHelper.ThrowEndOfFileException();
}
bytesRead += n;
} while (bytesRead < numBytes);
_stream.ReadExactly(_buffer.AsSpan(0, numBytes));

return _buffer;
}
Expand All @@ -547,17 +527,14 @@ protected virtual void FillBuffer(int numBytes)
throw new ArgumentOutOfRangeException(nameof(numBytes), SR.ArgumentOutOfRange_BinaryReaderFillBuffer);
}

int bytesRead = 0;
int n;

ThrowIfDisposed();

// Need to find a good threshold for calling ReadByte() repeatedly
// vs. calling Read(byte[], int, int) for both buffered & unbuffered
// streams.
if (numBytes == 1)
{
n = _stream.ReadByte();
int n = _stream.ReadByte();
if (n == -1)
{
ThrowHelper.ThrowEndOfFileException();
Expand All @@ -567,15 +544,7 @@ protected virtual void FillBuffer(int numBytes)
return;
}

do
{
n = _stream.Read(_buffer, bytesRead, numBytes - bytesRead);
if (n == 0)
{
ThrowHelper.ThrowEndOfFileException();
}
bytesRead += n;
} while (bytesRead < numBytes);
_stream.ReadExactly(_buffer.AsSpan(0, numBytes));
}

public int Read7BitEncodedInt()
Expand Down
Loading