Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Stream ReadAtLeast and ReadExactly #69272

Merged
merged 16 commits into from
May 20, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ private async Task ProcessRequest(Socket clientSocket, NetworkStream ns)
private async Task ProcessSocks4Request(Socket clientSocket, NetworkStream ns)
{
byte[] buffer = new byte[7];
await ReadToFillAsync(ns, buffer).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer).ConfigureAwait(false);

if (buffer[0] != 1)
throw new Exception("Only CONNECT is supported.");
Expand Down Expand Up @@ -148,7 +148,7 @@ private async Task ProcessSocks5Request(Socket clientSocket, NetworkStream ns)
throw new Exception("Early EOF");

byte[] buffer = new byte[1024];
await ReadToFillAsync(ns, buffer.AsMemory(0, nMethods)).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer.AsMemory(0, nMethods)).ConfigureAwait(false);

byte expectedAuthMethod = _username == null ? (byte)0 : (byte)2;
if (!buffer.AsSpan(0, nMethods).Contains(expectedAuthMethod))
Expand All @@ -165,11 +165,11 @@ private async Task ProcessSocks5Request(Socket clientSocket, NetworkStream ns)
throw new Exception("Bad subnegotiation version.");

int usernameLength = await ns.ReadByteAsync().ConfigureAwait(false);
await ReadToFillAsync(ns, buffer.AsMemory(0, usernameLength)).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer.AsMemory(0, usernameLength)).ConfigureAwait(false);
string username = Encoding.UTF8.GetString(buffer.AsSpan(0, usernameLength));

int passwordLength = await ns.ReadByteAsync().ConfigureAwait(false);
await ReadToFillAsync(ns, buffer.AsMemory(0, passwordLength)).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer.AsMemory(0, passwordLength)).ConfigureAwait(false);
string password = Encoding.UTF8.GetString(buffer.AsSpan(0, passwordLength));

if (username != _username || password != _password)
Expand All @@ -181,7 +181,7 @@ private async Task ProcessSocks5Request(Socket clientSocket, NetworkStream ns)
await ns.WriteAsync(new byte[] { 1, 0 }).ConfigureAwait(false);
}

await ReadToFillAsync(ns, buffer.AsMemory(0, 4)).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer.AsMemory(0, 4)).ConfigureAwait(false);
if (buffer[0] != 5)
throw new Exception("Bad protocol version.");
if (buffer[1] != 1)
Expand All @@ -191,26 +191,26 @@ private async Task ProcessSocks5Request(Socket clientSocket, NetworkStream ns)
switch (buffer[3])
{
case 1:
await ReadToFillAsync(ns, buffer.AsMemory(0, 4)).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer.AsMemory(0, 4)).ConfigureAwait(false);
remoteHost = new IPAddress(buffer.AsSpan(0, 4)).ToString();
break;
case 4:
await ReadToFillAsync(ns, buffer.AsMemory(0, 16)).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer.AsMemory(0, 16)).ConfigureAwait(false);
remoteHost = new IPAddress(buffer.AsSpan(0, 16)).ToString();
break;
case 3:
int length = await ns.ReadByteAsync().ConfigureAwait(false);
if (length == -1)
throw new Exception("Early EOF");
await ReadToFillAsync(ns, buffer.AsMemory(0, length)).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer.AsMemory(0, length)).ConfigureAwait(false);
remoteHost = Encoding.UTF8.GetString(buffer.AsSpan(0, length));
break;

default:
throw new Exception("Unknown address type.");
}

await ReadToFillAsync(ns, buffer.AsMemory(0, 2)).ConfigureAwait(false);
await ns.ReadExactlyAsync(buffer.AsMemory(0, 2)).ConfigureAwait(false);
int port = (buffer[0] << 8) + buffer[1];

await ns.WriteAsync(new byte[] { 5, 0, 0, 1, 0, 0, 0, 0, 0, 0 }).ConfigureAwait(false);
Expand Down Expand Up @@ -290,11 +290,6 @@ void HandleExceptions(Exception ex)
}
}

private ValueTask ReadToFillAsync(Stream stream, Memory<byte> buffer)
{
return stream.ReadExactlyAsync(buffer);
}

public async ValueTask DisposeAsync()
{
_listener.Dispose();
Expand Down
32 changes: 28 additions & 4 deletions src/libraries/System.Private.CoreLib/src/System/IO/Stream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;

namespace System.IO
{
Expand Down Expand Up @@ -332,14 +333,36 @@ static async ValueTask<int> FinishReadAsync(Task<int> readTask, byte[] localBuff
}
}

public async ValueTask ReadExactlyAsync(Memory<byte> buffer, CancellationToken cancellationToken = default) =>
_ = await ReadAtLeastAsyncCore(buffer, buffer.Length, throwOnEndOfStream: true, cancellationToken).ConfigureAwait(false);
public ValueTask ReadExactlyAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
ValueTask<int> vt = ReadAtLeastAsyncCore(buffer, buffer.Length, throwOnEndOfStream: true, cancellationToken);
if (vt.IsCompletedSuccessfully)
return default;

// use the ValueTask<int>'s backing object to create a ValueTask without allocating here.
object? obj = vt._obj;
Debug.Assert(obj is Task || obj is IValueTaskSource);

return obj is Task task ?
new ValueTask(task) :
new ValueTask((IValueTaskSource)obj, vt._token);
eerhardt marked this conversation as resolved.
Show resolved Hide resolved
}

public async ValueTask ReadExactlyAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
public ValueTask ReadExactlyAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
{
ValidateBufferArguments(buffer, offset, count);
eerhardt marked this conversation as resolved.
Show resolved Hide resolved

_ = await ReadAtLeastAsyncCore(buffer.AsMemory(offset, count), count, throwOnEndOfStream: true, cancellationToken).ConfigureAwait(false);
ValueTask<int> vt = ReadAtLeastAsyncCore(buffer.AsMemory(offset, count), count, throwOnEndOfStream: true, cancellationToken);
if (vt.IsCompletedSuccessfully)
return default;

// use the ValueTask<int>'s backing object to create a ValueTask without allocating here.
object? obj = vt._obj;
Debug.Assert(obj is Task || obj is IValueTaskSource);

return obj is Task task ?
new ValueTask(task) :
new ValueTask((IValueTaskSource)obj, vt._token);
}

public ValueTask<int> ReadAtLeastAsync(Memory<byte> buffer, int minimumBytes, bool throwOnEndOfStream = true, CancellationToken cancellationToken = default)
Expand All @@ -350,6 +373,7 @@ public ValueTask<int> ReadAtLeastAsync(Memory<byte> buffer, int minimumBytes, bo
}

// No argument checking is done here. It is up to the caller.
[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
eerhardt marked this conversation as resolved.
Show resolved Hide resolved
private async ValueTask<int> ReadAtLeastAsyncCore(Memory<byte> buffer, int minimumBytes, bool throwOnEndOfStream, CancellationToken cancellationToken)
{
int totalRead = 0;
Expand Down