Skip to content

Commit

Permalink
Add Stream ReadAtLeast and ReadExactly
Browse files Browse the repository at this point in the history
Adds methods to Stream to read at least a minimum amount of bytes, or a full buffer, of data from the stream.
ReadAtLeast allows for the caller to specify whether an exception should be thrown or not on the end of the stream.

Make use of the new methods where appropriate in net7.0.

Fix dotnet#16598
  • Loading branch information
eerhardt committed May 12, 2022
1 parent 844b298 commit 3c6cd7e
Show file tree
Hide file tree
Showing 20 changed files with 186 additions and 181 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal bool TryGetNextHeader(Stream archiveStream, bool copyData)
Span<byte> buffer = rented.AsSpan(0, TarHelpers.RecordSize); // minimumLength means the array could've been larger
buffer.Clear(); // Rented arrays aren't clean

TarHelpers.ReadOrThrow(archiveStream, buffer);
archiveStream.ReadExactly(buffer);

try
{
Expand Down Expand Up @@ -486,10 +486,7 @@ private void ReadExtendedAttributesBlock(Stream archiveStream)
}

byte[] buffer = new byte[(int)_size];
if (archiveStream.Read(buffer.AsSpan()) != _size)
{
throw new EndOfStreamException();
}
archiveStream.ReadExactly(buffer);

string dataAsString = TarHelpers.GetTrimmedUtf8String(buffer);

Expand Down Expand Up @@ -520,11 +517,7 @@ private void ReadGnuLongPathDataBlock(Stream archiveStream)
}

byte[] buffer = new byte[(int)_size];

if (archiveStream.Read(buffer.AsSpan()) != _size)
{
throw new EndOfStreamException();
}
archiveStream.ReadExactly(buffer);

string longPath = TarHelpers.GetTrimmedUtf8String(buffer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,6 @@ private static string GetTrimmedString(ReadOnlySpan<byte> buffer, Encoding encod
// removing the trailing null or space chars.
internal static string GetTrimmedUtf8String(ReadOnlySpan<byte> buffer) => GetTrimmedString(buffer, Encoding.UTF8);

// Reads the specified number of bytes and stores it in the byte buffer passed by reference.
// Throws if end of stream is reached.
internal static void ReadOrThrow(Stream archiveStream, Span<byte> buffer)
{
int totalRead = 0;
while (totalRead < buffer.Length)
{
int bytesRead = archiveStream.Read(buffer.Slice(totalRead));
if (bytesRead == 0)
{
throw new EndOfStreamException();
}
totalRead += bytesRead;
}
}

// Returns true if it successfully converts the specified string to a DateTimeOffset, false otherwise.
internal static bool TryConvertToDateTimeOffset(string value, out DateTimeOffset timestamp)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,10 @@ internal static Encoding GetEncoding(string text)
/// </summary>
internal static void ReadBytes(Stream stream, byte[] buffer, int bytesToRead)
{
int bytesLeftToRead = bytesToRead;

int totalBytesRead = 0;

while (bytesLeftToRead > 0)
int bytesRead = stream.ReadAtLeast(buffer.AsSpan(0, bytesToRead), bytesToRead, throwOnEndOfStream: false);
if (bytesRead < bytesToRead)
{
int bytesRead = stream.Read(buffer, totalBytesRead, bytesLeftToRead);
if (bytesRead == 0) throw new IOException(SR.UnexpectedEndOfStream);

totalBytesRead += bytesRead;
bytesLeftToRead -= bytesRead;
throw new IOException(SR.UnexpectedEndOfStream);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,18 +353,13 @@ private static ValueTask WriteAsync(Stream stream, Memory<byte> buffer, bool asy

private static async ValueTask ReadToFillAsync(Stream stream, Memory<byte> buffer, bool async)
{
while (buffer.Length != 0)
{
int bytesRead = async
? await stream.ReadAsync(buffer).ConfigureAwait(false)
: stream.Read(buffer.Span);
int bytesRead = async
? await stream.ReadAtLeastAsync(buffer, buffer.Length, throwOnEndOfStream: false).ConfigureAwait(false)
: stream.ReadAtLeast(buffer.Span, buffer.Length, throwOnEndOfStream: false);

if (bytesRead == 0)
{
throw new IOException(SR.net_http_invalid_response_premature_eof);
}

buffer = buffer[bytesRead..];
if (bytesRead < buffer.Length)
{
throw new IOException(SR.net_http_invalid_response_premature_eof);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,9 @@ void HandleExceptions(Exception ex)
}
}

private async ValueTask ReadToFillAsync(Stream stream, Memory<byte> buffer)
private ValueTask ReadToFillAsync(Stream stream, Memory<byte> buffer)
{
while (!buffer.IsEmpty)
{
int bytesRead = await stream.ReadAsync(buffer).ConfigureAwait(false);
if (bytesRead == 0)
throw new Exception("Incomplete request");

buffer = buffer.Slice(bytesRead);
}
return stream.ReadExactlyAsync(buffer);
}

public async ValueTask DisposeAsync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,24 +424,15 @@ private async ValueTask<int> ReadAsync<TIOAdapter>(Memory<byte> buffer, Cancella

static async ValueTask<int> ReadAllAsync(Stream stream, Memory<byte> buffer, bool allowZeroRead, CancellationToken cancellationToken)
{
int read = 0;

do
int read = await TIOAdapter.ReadAtLeastAsync(
stream, buffer, buffer.Length, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (read < buffer.Length)
{
int bytes = await TIOAdapter.ReadAsync(stream, buffer, cancellationToken).ConfigureAwait(false);
if (bytes == 0)
if (read != 0 || !allowZeroRead)
{
if (read != 0 || !allowZeroRead)
{
throw new IOException(SR.net_io_eof);
}
break;
throw new IOException(SR.net_io_eof);
}

buffer = buffer.Slice(bytes);
read += bytes;
}
while (!buffer.IsEmpty);

return read;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace System.Net.Security
internal interface IReadWriteAdapter
{
static abstract ValueTask<int> ReadAsync(Stream stream, Memory<byte> buffer, CancellationToken cancellationToken);
static abstract ValueTask<int> ReadAtLeastAsync(Stream stream, Memory<byte> buffer, int minimumBytes, bool throwOnEndOfStream, CancellationToken cancellationToken);
static abstract ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken);
static abstract Task FlushAsync(Stream stream, CancellationToken cancellationToken);
static abstract Task WaitAsync(TaskCompletionSource<bool> waiter);
Expand All @@ -20,6 +21,9 @@ internal interface IReadWriteAdapter
public static ValueTask<int> ReadAsync(Stream stream, Memory<byte> buffer, CancellationToken cancellationToken) =>
stream.ReadAsync(buffer, cancellationToken);

public static ValueTask<int> ReadAtLeastAsync(Stream stream, Memory<byte> buffer, int minimumBytes, bool throwOnEndOfStream, CancellationToken cancellationToken) =>
stream.ReadAtLeastAsync(buffer, minimumBytes, throwOnEndOfStream, cancellationToken);

public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
stream.WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken);

Expand All @@ -33,6 +37,9 @@ public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int
public static ValueTask<int> ReadAsync(Stream stream, Memory<byte> buffer, CancellationToken cancellationToken) =>
new ValueTask<int>(stream.Read(buffer.Span));

public static ValueTask<int> ReadAtLeastAsync(Stream stream, Memory<byte> buffer, int minimumBytes, bool throwOnEndOfStream, CancellationToken cancellationToken) =>
new ValueTask<int>(stream.ReadAtLeast(buffer.Span, minimumBytes, throwOnEndOfStream));

public static ValueTask WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
stream.Write(buffer, offset, count);
Expand Down
34 changes: 11 additions & 23 deletions src/libraries/System.Net.Security/src/System/Net/StreamFramer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,17 @@ internal sealed class StreamFramer

byte[] buffer = _readHeaderBuffer;

int bytesRead;
int offset = 0;
while (offset < buffer.Length)
int bytesRead = await TAdapter.ReadAtLeastAsync(
stream, buffer, buffer.Length, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (bytesRead < buffer.Length)
{
bytesRead = await TAdapter.ReadAsync(stream, buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
if (offset == 0)
{
// m_Eof, return null
_eof = true;
return null;
}

throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
// m_Eof, return null
_eof = true;
return null;
}

offset += bytesRead;
throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
}

_curReadHeader.CopyFrom(buffer, 0);
Expand All @@ -61,16 +54,11 @@ internal sealed class StreamFramer

buffer = new byte[_curReadHeader.PayloadSize];

offset = 0;
while (offset < buffer.Length)
bytesRead = await TAdapter.ReadAtLeastAsync(
stream, buffer, buffer.Length, throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (bytesRead < buffer.Length)
{
bytesRead = await TAdapter.ReadAsync(stream, buffer.AsMemory(offset), cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
}

offset += bytesRead;
throw new IOException(SR.Format(SR.net_io_readfailure, SR.net_io_connectionclosed));
}
return buffer;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -779,13 +779,16 @@ private async ValueTask<TResult> ReceiveAsyncPrivate<TResult>(Memory<byte> paylo
totalBytesReceived += receiveBufferBytesToCopy;
}

while (totalBytesReceived < limit)
if (totalBytesReceived < limit)
{
int numBytesRead = await _stream.ReadAsync(header.Compressed ?
_inflater!.Memory.Slice(totalBytesReceived, limit - totalBytesReceived) :
payloadBuffer.Slice(totalBytesReceived, limit - totalBytesReceived),
cancellationToken).ConfigureAwait(false);
if (numBytesRead <= 0)
int bytesToRead = limit - totalBytesReceived;
Memory<byte> readBuffer = header.Compressed ?
_inflater!.Memory.Slice(totalBytesReceived, bytesToRead) :
payloadBuffer.Slice(totalBytesReceived, bytesToRead);

int numBytesRead = await _stream.ReadAtLeastAsync(readBuffer, bytesToRead,
throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (numBytesRead < bytesToRead)
{
ThrowEOFUnexpected();
break;
Expand Down Expand Up @@ -1359,17 +1362,13 @@ private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, Canc
_receiveBufferOffset = 0;

// While we don't have enough data, read more.
while (_receiveBufferCount < minimumRequiredBytes)
int numRead = await _stream.ReadAtLeastAsync(_receiveBuffer.Slice(_receiveBufferCount), minimumRequiredBytes,
throwOnEndOfStream: false, cancellationToken).ConfigureAwait(false);
if (numRead < minimumRequiredBytes)
{
int numRead = await _stream.ReadAsync(_receiveBuffer.Slice(_receiveBufferCount), cancellationToken).ConfigureAwait(false);
Debug.Assert(numRead >= 0, $"Expected non-negative bytes read, got {numRead}");
if (numRead <= 0)
{
ThrowEOFUnexpected();
break;
}
_receiveBufferCount += numRead;
ThrowEOFUnexpected();
}
_receiveBufferCount += numRead;
}
}

Expand Down
39 changes: 4 additions & 35 deletions src/libraries/System.Private.CoreLib/src/System/IO/BinaryReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -483,18 +483,7 @@ public virtual byte[] ReadBytes(int count)
}

byte[] result = new byte[count];
int numRead = 0;
do
{
int n = _stream.Read(result, numRead, count);
if (n == 0)
{
break;
}

numRead += n;
count -= n;
} while (count > 0);
int numRead = _stream.ReadAtLeast(result, result.Length, throwOnEndOfStream: false);

if (numRead != result.Length)
{
Expand All @@ -521,16 +510,7 @@ private ReadOnlySpan<byte> InternalRead(int numBytes)
{
ThrowIfDisposed();

int bytesRead = 0;
do
{
int n = _stream.Read(_buffer, bytesRead, numBytes - bytesRead);
if (n == 0)
{
ThrowHelper.ThrowEndOfFileException();
}
bytesRead += n;
} while (bytesRead < numBytes);
_stream.ReadExactly(_buffer.AsSpan(0, numBytes));

return _buffer;
}
Expand All @@ -547,17 +527,14 @@ protected virtual void FillBuffer(int numBytes)
throw new ArgumentOutOfRangeException(nameof(numBytes), SR.ArgumentOutOfRange_BinaryReaderFillBuffer);
}

int bytesRead = 0;
int n;

ThrowIfDisposed();

// Need to find a good threshold for calling ReadByte() repeatedly
// vs. calling Read(byte[], int, int) for both buffered & unbuffered
// streams.
if (numBytes == 1)
{
n = _stream.ReadByte();
int n = _stream.ReadByte();
if (n == -1)
{
ThrowHelper.ThrowEndOfFileException();
Expand All @@ -567,15 +544,7 @@ protected virtual void FillBuffer(int numBytes)
return;
}

do
{
n = _stream.Read(_buffer, bytesRead, numBytes - bytesRead);
if (n == 0)
{
ThrowHelper.ThrowEndOfFileException();
}
bytesRead += n;
} while (bytesRead < numBytes);
_stream.ReadExactly(_buffer.AsSpan(0, numBytes));
}

public int Read7BitEncodedInt()
Expand Down
Loading

0 comments on commit 3c6cd7e

Please sign in to comment.