diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs index 0cff6c8eac66a..3f13b716cd4c7 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/NetworkStream.cs @@ -250,7 +250,17 @@ public override int Read(Span buffer) ThrowIfDisposed(); if (!CanRead) throw new InvalidOperationException(SR.net_writeonlystream); - int bytesRead = _streamSocket.Receive(buffer, SocketFlags.None, out SocketError errorCode); + int bytesRead; + SocketError errorCode; + try + { + bytesRead = _streamSocket.Receive(buffer, SocketFlags.None, out errorCode); + } + catch (Exception exception) when (!(exception is OutOfMemoryException)) + { + throw GetCustomException(SR.Format(SR.net_io_readfailure, exception.Message), exception); + } + if (errorCode != SocketError.Success) { var socketException = new SocketException((int)errorCode); @@ -320,7 +330,16 @@ public override void Write(ReadOnlySpan buffer) ThrowIfDisposed(); if (!CanWrite) throw new InvalidOperationException(SR.net_readonlystream); - _streamSocket.Send(buffer, SocketFlags.None, out SocketError errorCode); + SocketError errorCode; + try + { + _streamSocket.Send(buffer, SocketFlags.None, out errorCode); + } + catch (Exception exception) when (!(exception is OutOfMemoryException)) + { + throw GetCustomException(SR.Format(SR.net_io_writefailure, exception.Message), exception); + } + if (errorCode != SocketError.Success) { var socketException = new SocketException((int)errorCode); diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs index b812da8c7fd7e..4f03d4ab4c3ea 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/NetworkStreamTest.cs @@ -307,8 +307,10 @@ await RunWithConnectedNetworkStreamsAsync((server, _) => }); } - [Fact] - public async Task DisposeSocketDirectly_ReadWriteThrowNetworkException() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task DisposeSocketDirectly_ReadWriteThrowNetworkException(bool derivedNetworkStream) { using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) @@ -318,20 +320,22 @@ public async Task DisposeSocketDirectly_ReadWriteThrowNetworkException() Task acceptTask = listener.AcceptAsync(); await Task.WhenAll(acceptTask, client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listener.LocalEndPoint).Port))); - using (Socket serverSocket = await acceptTask) - using (DerivedNetworkStream server = new DerivedNetworkStream(serverSocket)) - { - serverSocket.Dispose(); + using Socket serverSocket = await acceptTask; + using NetworkStream server = derivedNetworkStream ? (NetworkStream)new DerivedNetworkStream(serverSocket) : new NetworkStream(serverSocket); + + serverSocket.Dispose(); - Assert.Throws(() => server.Read(new byte[1], 0, 1)); - Assert.Throws(() => server.Write(new byte[1], 0, 1)); + Assert.Throws(() => server.Read(new byte[1], 0, 1)); + Assert.Throws(() => server.Write(new byte[1], 0, 1)); - Assert.Throws(() => server.BeginRead(new byte[1], 0, 1, null, null)); - Assert.Throws(() => server.BeginWrite(new byte[1], 0, 1, null, null)); + Assert.Throws(() => server.Read((Span)new byte[1])); + Assert.Throws(() => server.Write((ReadOnlySpan)new byte[1])); - Assert.Throws(() => { server.ReadAsync(new byte[1], 0, 1); }); - Assert.Throws(() => { server.WriteAsync(new byte[1], 0, 1); }); - } + Assert.Throws(() => server.BeginRead(new byte[1], 0, 1, null, null)); + Assert.Throws(() => server.BeginWrite(new byte[1], 0, 1, null, null)); + + Assert.Throws(() => { server.ReadAsync(new byte[1], 0, 1); }); + Assert.Throws(() => { server.WriteAsync(new byte[1], 0, 1); }); } }