Skip to content

Commit

Permalink
Graceful handling for HTTP/2 GoAway frames. (#733)
Browse files Browse the repository at this point in the history
* Refactor HTTP/2 stream events

* Refactor HTTP/2 stream events

* Graceful handling of HTTP/2 GoAway frames

* Add to CHANGELOG

* Remove unneccessary getattr

* Conditional fix

* Conditional fix
  • Loading branch information
tomchristie authored Jun 26, 2023
1 parent aacdbb9 commit 8d31f0b
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## unreleased

- The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699)
- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. (#730)
- Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717)

## 0.17.2 (May 23th, 2023)
Expand Down
71 changes: 50 additions & 21 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,26 @@ def __init__(
self._sent_connection_init = False
self._used_all_stream_ids = False
self._connection_error = False
self._events: typing.Dict[int, h2.events.Event] = {}

# Mapping from stream ID to response stream events.
self._events: typing.Dict[
int,
typing.Union[
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
],
] = {}

# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: typing.Optional[
h2.events.ConnectionTerminated
] = None

self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._connection_error_event: typing.Optional[h2.events.Event] = None

async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand Down Expand Up @@ -111,6 +127,7 @@ async def handle_async_request(self, request: Request) -> Response:
self._events[stream_id] = []
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
self._used_all_stream_ids = True
self._request_count -= 1
raise ConnectionNotAvailable()

try:
Expand Down Expand Up @@ -152,8 +169,8 @@ async def handle_async_request(self, request: Request) -> Response:
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_error_event:
raise RemoteProtocolError(self._connection_error_event)
if self._connection_terminated: # pragma: nocover
raise RemoteProtocolError(self._connection_terminated)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover
Expand Down Expand Up @@ -292,12 +309,14 @@ async def _receive_response_body(
self._h2_state.acknowledge_received_data(amount, stream_id)
await self._write_outgoing_data(request)
yield event.data
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
elif isinstance(event, h2.events.StreamEnded):
break

async def _receive_stream_event(
self, request: Request, stream_id: int
) -> h2.events.Event:
) -> typing.Union[
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
]:
"""
Return the next available event for a given stream ID.
Expand All @@ -306,8 +325,7 @@ async def _receive_stream_event(
while not self._events.get(stream_id):
await self._receive_events(request, stream_id)
event = self._events[stream_id].pop(0)
# The StreamReset event applies to a single stream.
if hasattr(event, "error_code"):
if isinstance(event, h2.events.StreamReset):
raise RemoteProtocolError(event)
return event

Expand All @@ -319,8 +337,12 @@ async def _receive_events(
for a given stream ID.
"""
async with self._read_lock:
if self._connection_error_event is not None: # pragma: nocover
raise RemoteProtocolError(self._connection_error_event)
if self._connection_terminated is not None:
last_stream_id = self._connection_terminated.last_stream_id
if stream_id and last_stream_id and stream_id > last_stream_id:
self._request_count -= 1
raise ConnectionNotAvailable()
raise RemoteProtocolError(self._connection_terminated)

# This conditional is a bit icky. We don't want to block reading if we've
# actually got an event to return for a given stream. We need to do that
Expand All @@ -338,16 +360,20 @@ async def _receive_events(
await self._receive_remote_settings_change(event)
trace.return_value = event

event_stream_id = getattr(event, "stream_id", 0)

# The ConnectionTerminatedEvent applies to the entire connection,
# and should be saved so it can be raised on all streams.
if hasattr(event, "error_code") and event_stream_id == 0:
self._connection_error_event = event
raise RemoteProtocolError(event)

if event_stream_id in self._events:
self._events[event_stream_id].append(event)
elif isinstance(
event,
(
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
),
):
if event.stream_id in self._events:
self._events[event.stream_id].append(event)

elif isinstance(event, h2.events.ConnectionTerminated):
self._connection_terminated = event

await self._write_outgoing_data(request)

Expand All @@ -372,7 +398,10 @@ async def _response_closed(self, stream_id: int) -> None:
await self._max_streams_semaphore.release()
del self._events[stream_id]
async with self._state_lock:
if self._state == HTTPConnectionState.ACTIVE and not self._events:
if self._connection_terminated and not self._events:
await self.aclose()

elif self._state == HTTPConnectionState.ACTIVE and not self._events:
self._state = HTTPConnectionState.IDLE
if self._keepalive_expiry is not None:
now = time.monotonic()
Expand Down
71 changes: 50 additions & 21 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,26 @@ def __init__(
self._sent_connection_init = False
self._used_all_stream_ids = False
self._connection_error = False
self._events: typing.Dict[int, h2.events.Event] = {}

# Mapping from stream ID to response stream events.
self._events: typing.Dict[
int,
typing.Union[
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
],
] = {}

# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: typing.Optional[
h2.events.ConnectionTerminated
] = None

self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._connection_error_event: typing.Optional[h2.events.Event] = None

def handle_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand Down Expand Up @@ -111,6 +127,7 @@ def handle_request(self, request: Request) -> Response:
self._events[stream_id] = []
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
self._used_all_stream_ids = True
self._request_count -= 1
raise ConnectionNotAvailable()

try:
Expand Down Expand Up @@ -152,8 +169,8 @@ def handle_request(self, request: Request) -> Response:
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_error_event:
raise RemoteProtocolError(self._connection_error_event)
if self._connection_terminated: # pragma: nocover
raise RemoteProtocolError(self._connection_terminated)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover
Expand Down Expand Up @@ -292,12 +309,14 @@ def _receive_response_body(
self._h2_state.acknowledge_received_data(amount, stream_id)
self._write_outgoing_data(request)
yield event.data
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
elif isinstance(event, h2.events.StreamEnded):
break

def _receive_stream_event(
self, request: Request, stream_id: int
) -> h2.events.Event:
) -> typing.Union[
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
]:
"""
Return the next available event for a given stream ID.
Expand All @@ -306,8 +325,7 @@ def _receive_stream_event(
while not self._events.get(stream_id):
self._receive_events(request, stream_id)
event = self._events[stream_id].pop(0)
# The StreamReset event applies to a single stream.
if hasattr(event, "error_code"):
if isinstance(event, h2.events.StreamReset):
raise RemoteProtocolError(event)
return event

Expand All @@ -319,8 +337,12 @@ def _receive_events(
for a given stream ID.
"""
with self._read_lock:
if self._connection_error_event is not None: # pragma: nocover
raise RemoteProtocolError(self._connection_error_event)
if self._connection_terminated is not None:
last_stream_id = self._connection_terminated.last_stream_id
if stream_id and last_stream_id and stream_id > last_stream_id:
self._request_count -= 1
raise ConnectionNotAvailable()
raise RemoteProtocolError(self._connection_terminated)

# This conditional is a bit icky. We don't want to block reading if we've
# actually got an event to return for a given stream. We need to do that
Expand All @@ -338,16 +360,20 @@ def _receive_events(
self._receive_remote_settings_change(event)
trace.return_value = event

event_stream_id = getattr(event, "stream_id", 0)

# The ConnectionTerminatedEvent applies to the entire connection,
# and should be saved so it can be raised on all streams.
if hasattr(event, "error_code") and event_stream_id == 0:
self._connection_error_event = event
raise RemoteProtocolError(event)

if event_stream_id in self._events:
self._events[event_stream_id].append(event)
elif isinstance(
event,
(
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
),
):
if event.stream_id in self._events:
self._events[event.stream_id].append(event)

elif isinstance(event, h2.events.ConnectionTerminated):
self._connection_terminated = event

self._write_outgoing_data(request)

Expand All @@ -372,7 +398,10 @@ def _response_closed(self, stream_id: int) -> None:
self._max_streams_semaphore.release()
del self._events[stream_id]
with self._state_lock:
if self._state == HTTPConnectionState.ACTIVE and not self._events:
if self._connection_terminated and not self._events:
self.close()

elif self._state == HTTPConnectionState.ACTIVE and not self._events:
self._state = HTTPConnectionState.IDLE
if self._keepalive_expiry is not None:
now = time.monotonic()
Expand Down
Loading

0 comments on commit 8d31f0b

Please sign in to comment.