From 8d31f0badf665719e9faaf0132b47d20c01385b2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 26 Jun 2023 10:19:18 +0100 Subject: [PATCH] Graceful handling for HTTP/2 GoAway frames. (#733) * Refactor HTTP/2 stream events * Refactor HTTP/2 stream events * Graceful handling of HTTP/2 GoAway frames * Add to CHANGELOG * Remove unnecessary getattr * Conditional fix * Conditional fix --- CHANGELOG.md | 1 + httpcore/_async/http2.py | 71 +++++++++++----- httpcore/_sync/http2.py | 71 +++++++++++----- tests/_async/test_connection_pool.py | 119 +++++++++++++++++++++++++++ tests/_async/test_http2.py | 12 ++- tests/_sync/test_connection_pool.py | 119 +++++++++++++++++++++++++++ tests/_sync/test_http2.py | 12 ++- 7 files changed, 357 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5f8fdd7..911371d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## unreleased - The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699) +- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. (#730) - Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717) ## 0.17.2 (May 23th, 2023) diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index af7c5b8c..cc957601 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -61,10 +61,26 @@ def __init__( self._sent_connection_init = False self._used_all_stream_ids = False self._connection_error = False - self._events: typing.Dict[int, h2.events.Event] = {} + + # Mapping from stream ID to response stream events. 
+ self._events: typing.Dict[ + int, + typing.Union[ + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ], + ] = {} + + # Connection terminated events are stored as state since + # we need to handle them for all streams. + self._connection_terminated: typing.Optional[ + h2.events.ConnectionTerminated + ] = None + self._read_exception: typing.Optional[Exception] = None self._write_exception: typing.Optional[Exception] = None - self._connection_error_event: typing.Optional[h2.events.Event] = None async def handle_async_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): @@ -111,6 +127,7 @@ async def handle_async_request(self, request: Request) -> Response: self._events[stream_id] = [] except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover self._used_all_stream_ids = True + self._request_count -= 1 raise ConnectionNotAvailable() try: @@ -152,8 +169,8 @@ async def handle_async_request(self, request: Request) -> Response: # # In this case we'll have stored the event, and should raise # it as a RemoteProtocolError. - if self._connection_error_event: - raise RemoteProtocolError(self._connection_error_event) + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) # If h2 raises a protocol error in some other state then we # must somehow have made a protocol violation. 
raise LocalProtocolError(exc) # pragma: nocover @@ -292,12 +309,14 @@ async def _receive_response_body( self._h2_state.acknowledge_received_data(amount, stream_id) await self._write_outgoing_data(request) yield event.data - elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)): + elif isinstance(event, h2.events.StreamEnded): break async def _receive_stream_event( self, request: Request, stream_id: int - ) -> h2.events.Event: + ) -> typing.Union[ + h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded + ]: """ Return the next available event for a given stream ID. @@ -306,8 +325,7 @@ async def _receive_stream_event( while not self._events.get(stream_id): await self._receive_events(request, stream_id) event = self._events[stream_id].pop(0) - # The StreamReset event applies to a single stream. - if hasattr(event, "error_code"): + if isinstance(event, h2.events.StreamReset): raise RemoteProtocolError(event) return event @@ -319,8 +337,12 @@ async def _receive_events( for a given stream ID. """ async with self._read_lock: - if self._connection_error_event is not None: # pragma: nocover - raise RemoteProtocolError(self._connection_error_event) + if self._connection_terminated is not None: + last_stream_id = self._connection_terminated.last_stream_id + if stream_id and last_stream_id and stream_id > last_stream_id: + self._request_count -= 1 + raise ConnectionNotAvailable() + raise RemoteProtocolError(self._connection_terminated) # This conditional is a bit icky. We don't want to block reading if we've # actually got an event to return for a given stream. We need to do that @@ -338,16 +360,20 @@ async def _receive_events( await self._receive_remote_settings_change(event) trace.return_value = event - event_stream_id = getattr(event, "stream_id", 0) - - # The ConnectionTerminatedEvent applies to the entire connection, - # and should be saved so it can be raised on all streams. 
- if hasattr(event, "error_code") and event_stream_id == 0: - self._connection_error_event = event - raise RemoteProtocolError(event) - - if event_stream_id in self._events: - self._events[event_stream_id].append(event) + elif isinstance( + event, + ( + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ), + ): + if event.stream_id in self._events: + self._events[event.stream_id].append(event) + + elif isinstance(event, h2.events.ConnectionTerminated): + self._connection_terminated = event await self._write_outgoing_data(request) @@ -372,7 +398,10 @@ async def _response_closed(self, stream_id: int) -> None: await self._max_streams_semaphore.release() del self._events[stream_id] async with self._state_lock: - if self._state == HTTPConnectionState.ACTIVE and not self._events: + if self._connection_terminated and not self._events: + await self.aclose() + + elif self._state == HTTPConnectionState.ACTIVE and not self._events: self._state = HTTPConnectionState.IDLE if self._keepalive_expiry is not None: now = time.monotonic() diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 2e5fe4a3..fbbc67bf 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -61,10 +61,26 @@ def __init__( self._sent_connection_init = False self._used_all_stream_ids = False self._connection_error = False - self._events: typing.Dict[int, h2.events.Event] = {} + + # Mapping from stream ID to response stream events. + self._events: typing.Dict[ + int, + typing.Union[ + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ], + ] = {} + + # Connection terminated events are stored as state since + # we need to handle them for all streams. 
+ self._connection_terminated: typing.Optional[ + h2.events.ConnectionTerminated + ] = None + self._read_exception: typing.Optional[Exception] = None self._write_exception: typing.Optional[Exception] = None - self._connection_error_event: typing.Optional[h2.events.Event] = None def handle_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): @@ -111,6 +127,7 @@ def handle_request(self, request: Request) -> Response: self._events[stream_id] = [] except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover self._used_all_stream_ids = True + self._request_count -= 1 raise ConnectionNotAvailable() try: @@ -152,8 +169,8 @@ def handle_request(self, request: Request) -> Response: # # In this case we'll have stored the event, and should raise # it as a RemoteProtocolError. - if self._connection_error_event: - raise RemoteProtocolError(self._connection_error_event) + if self._connection_terminated: # pragma: nocover + raise RemoteProtocolError(self._connection_terminated) # If h2 raises a protocol error in some other state then we # must somehow have made a protocol violation. raise LocalProtocolError(exc) # pragma: nocover @@ -292,12 +309,14 @@ def _receive_response_body( self._h2_state.acknowledge_received_data(amount, stream_id) self._write_outgoing_data(request) yield event.data - elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)): + elif isinstance(event, h2.events.StreamEnded): break def _receive_stream_event( self, request: Request, stream_id: int - ) -> h2.events.Event: + ) -> typing.Union[ + h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded + ]: """ Return the next available event for a given stream ID. @@ -306,8 +325,7 @@ def _receive_stream_event( while not self._events.get(stream_id): self._receive_events(request, stream_id) event = self._events[stream_id].pop(0) - # The StreamReset event applies to a single stream. 
- if hasattr(event, "error_code"): + if isinstance(event, h2.events.StreamReset): raise RemoteProtocolError(event) return event @@ -319,8 +337,12 @@ def _receive_events( for a given stream ID. """ with self._read_lock: - if self._connection_error_event is not None: # pragma: nocover - raise RemoteProtocolError(self._connection_error_event) + if self._connection_terminated is not None: + last_stream_id = self._connection_terminated.last_stream_id + if stream_id and last_stream_id and stream_id > last_stream_id: + self._request_count -= 1 + raise ConnectionNotAvailable() + raise RemoteProtocolError(self._connection_terminated) # This conditional is a bit icky. We don't want to block reading if we've # actually got an event to return for a given stream. We need to do that @@ -338,16 +360,20 @@ def _receive_events( self._receive_remote_settings_change(event) trace.return_value = event - event_stream_id = getattr(event, "stream_id", 0) - - # The ConnectionTerminatedEvent applies to the entire connection, - # and should be saved so it can be raised on all streams. 
- if hasattr(event, "error_code") and event_stream_id == 0: - self._connection_error_event = event - raise RemoteProtocolError(event) - - if event_stream_id in self._events: - self._events[event_stream_id].append(event) + elif isinstance( + event, + ( + h2.events.ResponseReceived, + h2.events.DataReceived, + h2.events.StreamEnded, + h2.events.StreamReset, + ), + ): + if event.stream_id in self._events: + self._events[event.stream_id].append(event) + + elif isinstance(event, h2.events.ConnectionTerminated): + self._connection_terminated = event self._write_outgoing_data(request) @@ -372,7 +398,10 @@ def _response_closed(self, stream_id: int) -> None: self._max_streams_semaphore.release() del self._events[stream_id] with self._state_lock: - if self._state == HTTPConnectionState.ACTIVE and not self._events: + if self._connection_terminated and not self._events: + self.close() + + elif self._state == HTTPConnectionState.ACTIVE and not self._events: self._state = HTTPConnectionState.IDLE if self._keepalive_expiry is not None: now = time.monotonic() diff --git a/tests/_async/test_connection_pool.py b/tests/_async/test_connection_pool.py index df0199ab..2392ca17 100644 --- a/tests/_async/test_connection_pool.py +++ b/tests/_async/test_connection_pool.py @@ -1,6 +1,8 @@ import logging import typing +import hpack +import hyperframe.frame import pytest import trio as concurrency @@ -111,6 +113,123 @@ async def test_connection_pool_with_close(): assert info == [] +@pytest.mark.anyio +async def test_connection_pool_with_http2(): + """ + Test a connection pool with HTTP/2 requests. 
+ """ + network_backend = httpcore.AsyncMockBackend( + buffer=[ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + hyperframe.frame.HeadersFrame( + stream_id=3, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=3, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ], + http2=True, + ) + + async with httpcore.AsyncConnectionPool( + network_backend=network_backend, + ) as pool: + # Sending an intial request, which once complete will return to the pool, IDLE. + response = await pool.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + + # Sending a second request to the same origin will reuse the existing IDLE connection. + response = await pool.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + + +@pytest.mark.anyio +async def test_connection_pool_with_http2_goaway(): + """ + Test a connection pool with HTTP/2 requests, that cleanly disconnects + with a GoAway frame after the first request. 
+ """ + network_backend = httpcore.AsyncMockBackend( + buffer=[ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + hyperframe.frame.GoAwayFrame( + stream_id=0, error_code=0, last_stream_id=1 + ).serialize(), + b"", + ], + http2=True, + ) + + async with httpcore.AsyncConnectionPool( + network_backend=network_backend, + ) as pool: + # Sending an intial request, which once complete will return to the pool, IDLE. + response = await pool.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + + # Sending a second request to the same origin will require a new connection. + response = await pool.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" 
+ + info = [repr(c) for c in pool.connections] + assert info == [ + "", + "", + ] + + @pytest.mark.anyio async def test_trace_request(): """ diff --git a/tests/_async/test_http2.py b/tests/_async/test_http2.py index 59ba158e..b4ec6648 100644 --- a/tests/_async/test_http2.py +++ b/tests/_async/test_http2.py @@ -66,7 +66,9 @@ async def test_http2_connection_closed(): stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] ).serialize(), # Connection is closed after the first response - hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(), + hyperframe.frame.GoAwayFrame( + stream_id=0, error_code=0, last_stream_id=1 + ).serialize(), ] ) async with httpcore.AsyncHTTP2Connection( @@ -74,7 +76,7 @@ async def test_http2_connection_closed(): ) as conn: await conn.request("GET", "https://example.com/") - with pytest.raises(httpcore.RemoteProtocolError): + with pytest.raises(httpcore.ConnectionNotAvailable): await conn.request("GET", "https://example.com/") assert not conn.is_available() @@ -211,9 +213,13 @@ async def test_http2_connection_with_goaway(): ] ) async with httpcore.AsyncHTTP2Connection(origin=origin, stream=stream) as conn: + # The initial request has been closed midway, with an unrecoverable error. with pytest.raises(httpcore.RemoteProtocolError): await conn.request("GET", "https://example.com/") - with pytest.raises(httpcore.RemoteProtocolError): + + # The second request can receive a graceful `ConnectionNotAvailable`, + # and may be retried on a new connection. 
+ with pytest.raises(httpcore.ConnectionNotAvailable): await conn.request("GET", "https://example.com/") diff --git a/tests/_sync/test_connection_pool.py b/tests/_sync/test_connection_pool.py index aafa68aa..287c2bcc 100644 --- a/tests/_sync/test_connection_pool.py +++ b/tests/_sync/test_connection_pool.py @@ -1,6 +1,8 @@ import logging import typing +import hpack +import hyperframe.frame import pytest from tests import concurrency @@ -112,6 +114,123 @@ def test_connection_pool_with_close(): +def test_connection_pool_with_http2(): + """ + Test a connection pool with HTTP/2 requests. + """ + network_backend = httpcore.MockBackend( + buffer=[ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + hyperframe.frame.HeadersFrame( + stream_id=3, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=3, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ], + http2=True, + ) + + with httpcore.ConnectionPool( + network_backend=network_backend, + ) as pool: + # Sending an intial request, which once complete will return to the pool, IDLE. + response = pool.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + + # Sending a second request to the same origin will reuse the existing IDLE connection. + response = pool.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" 
+ + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + + + +def test_connection_pool_with_http2_goaway(): + """ + Test a connection pool with HTTP/2 requests, that cleanly disconnects + with a GoAway frame after the first request. + """ + network_backend = httpcore.MockBackend( + buffer=[ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + hyperframe.frame.GoAwayFrame( + stream_id=0, error_code=0, last_stream_id=1 + ).serialize(), + b"", + ], + http2=True, + ) + + with httpcore.ConnectionPool( + network_backend=network_backend, + ) as pool: + # Sending an intial request, which once complete will return to the pool, IDLE. + response = pool.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" + + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + + # Sending a second request to the same origin will require a new connection. + response = pool.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" 
+ + info = [repr(c) for c in pool.connections] + assert info == [ + "", + "", + ] + + + def test_trace_request(): """ The 'trace' request extension allows for a callback function to inspect the diff --git a/tests/_sync/test_http2.py b/tests/_sync/test_http2.py index 2cb353e1..695359bd 100644 --- a/tests/_sync/test_http2.py +++ b/tests/_sync/test_http2.py @@ -66,7 +66,9 @@ def test_http2_connection_closed(): stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] ).serialize(), # Connection is closed after the first response - hyperframe.frame.GoAwayFrame(stream_id=0, error_code=0).serialize(), + hyperframe.frame.GoAwayFrame( + stream_id=0, error_code=0, last_stream_id=1 + ).serialize(), ] ) with httpcore.HTTP2Connection( @@ -74,7 +76,7 @@ def test_http2_connection_closed(): ) as conn: conn.request("GET", "https://example.com/") - with pytest.raises(httpcore.RemoteProtocolError): + with pytest.raises(httpcore.ConnectionNotAvailable): conn.request("GET", "https://example.com/") assert not conn.is_available() @@ -211,9 +213,13 @@ def test_http2_connection_with_goaway(): ] ) with httpcore.HTTP2Connection(origin=origin, stream=stream) as conn: + # The initial request has been closed midway, with an unrecoverable error. with pytest.raises(httpcore.RemoteProtocolError): conn.request("GET", "https://example.com/") - with pytest.raises(httpcore.RemoteProtocolError): + + # The second request can receive a graceful `ConnectionNotAvailable`, + # and may be retried on a new connection. + with pytest.raises(httpcore.ConnectionNotAvailable): conn.request("GET", "https://example.com/")