Skip to content

Commit

Permalink
Fix ConnectionResetError not being raised when the transport is closed (
Browse files Browse the repository at this point in the history
#7180)

<!-- Thank you for your contribution! -->

`ConnectionResetError` will always be raised when `StreamWriter.write`
is called after `connection_lost` has been called on the `BaseProtocol`

<!-- Please give a short brief about these changes. -->

Restores pre 3.8.3 behavior

fixes #7172

- [x] I think the code is well written
- [x] Unit tests for the changes exist
- [x] Documentation reflects the changes
- [x] If you provide code modification, please add yourself to
`CONTRIBUTORS.txt`
  * The format is &lt;Name&gt; &lt;Surname&gt;.
  * Please keep alphabetical order, the file is sorted by names.
- [x] Add a new news fragment into the `CHANGES` folder
  * name it `<issue_id>.<type>` for example (588.bugfix)
* if you don't have an `issue_id` change it to the pr id after creating
the pr
  * ensure type is one of the following:
    * `.feature`: Signifying a new feature.
    * `.bugfix`: Signifying a bug fix.
    * `.doc`: Signifying a documentation improvement.
    * `.removal`: Signifying a deprecation or removal of public API.
* `.misc`: A ticket has been closed, but it is not of interest to users.
* Make sure to use full sentences with correct case and punctuation, for
example: "Fix issue with non-ascii contents in doctest text files."

---------

Co-authored-by: Sviatoslav Sydorenko <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sam Bull <[email protected]>
(cherry picked from commit 974323f)
  • Loading branch information
bdraco authored and Dreamsorcerer committed Feb 10, 2023
1 parent e0a7865 commit b7d33c9
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGES/7180.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``ConnectionResetError`` will always be raised when ``StreamWriter.write`` is called after ``connection_lost`` has been called on the ``BaseProtocol``
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Ilya Gruzinov
Ingmar Steen
Ivan Lakovic
Ivan Larin
J. Nick Koston
Jacob Champion
Jaesung Lee
Jake Davis
Expand Down
9 changes: 6 additions & 3 deletions aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop: asyncio.AbstractEventLoop = loop
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
self._connection_lost = False
self._reading_paused = False

self.transport: Optional[asyncio.Transport] = None

@property
def connected(self) -> bool:
"""Return True if the connection is open."""
return self.transport is not None

def pause_writing(self) -> None:
assert not self._paused
self._paused = True
Expand Down Expand Up @@ -59,7 +63,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.transport = tr

def connection_lost(self, exc: Optional[BaseException]) -> None:
self._connection_lost = True
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
Expand All @@ -76,7 +79,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
waiter.set_exception(exc)

async def _drain_helper(self) -> None:
if self._connection_lost:
if not self.connected:
raise ConnectionResetError("Connection lost")
if not self._paused:
return
Expand Down
10 changes: 4 additions & 6 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
on_headers_sent: _T_OnHeadersSent = None,
) -> None:
self._protocol = protocol
self._transport = protocol.transport

self.loop = loop
self.length = None
Expand All @@ -52,7 +51,7 @@ def __init__(

@property
def transport(self) -> Optional[asyncio.Transport]:
return self._transport
return self._protocol.transport

@property
def protocol(self) -> BaseProtocol:
Expand All @@ -71,10 +70,10 @@ def _write(self, chunk: bytes) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size

if self._transport is None or self._transport.is_closing():
transport = self.transport
if not self._protocol.connected or transport is None or transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
self._transport.write(chunk)
transport.write(chunk)

async def write(
self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
Expand Down Expand Up @@ -159,7 +158,6 @@ async def write_eof(self, chunk: bytes = b"") -> None:
await self.drain()

self._eof = True
self._transport = None

async def drain(self) -> None:
"""Flush the write buffer.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,22 @@ async def test_connection_lost_not_paused() -> None:
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
assert pr.connected
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost
assert not pr.connected


async def test_connection_lost_paused_without_waiter() -> None:
loop = asyncio.get_event_loop()
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
assert pr.connected
pr.pause_writing()
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost
assert not pr.connected


async def test_drain_lost() -> None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,17 @@ async def test_eof_received(loop) -> None:
assert proto._read_timeout_handle is not None
proto.eof_received()
assert proto._read_timeout_handle is None


async def test_connection_lost_sets_transport_to_none(loop: Any, mocker: Any) -> None:
"""Ensure that the transport is set to None when the connection is lost.
This ensures the writer knows that the connection is closed.
"""
proto = ResponseHandler(loop=loop)
proto.connection_made(mocker.Mock())
assert proto.transport is not None

proto.connection_lost(OSError())

assert proto.transport is None
17 changes: 17 additions & 0 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,23 @@ async def test_write_to_closing_transport(protocol, transport, loop) -> None:
await msg.write(b"After closing")


async def test_write_to_closed_transport(
protocol: Any, transport: Any, loop: Any
) -> None:
"""Test that writing to a closed transport raises ConnectionResetError.
The StreamWriter checks to see if protocol.transport is None before
writing to the transport. If it is None, it raises ConnectionResetError.
"""
msg = http.StreamWriter(protocol, loop)

await msg.write(b"Before transport close")
protocol.transport = None

with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"):
await msg.write(b"After transport closed")


async def test_drain(protocol, transport, loop) -> None:
msg = http.StreamWriter(protocol, loop)
await msg.drain()
Expand Down

0 comments on commit b7d33c9

Please sign in to comment.