Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graceful handling for HTTP/2 GoAway frames. #733

Merged
merged 7 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## unreleased

- The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699)
- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. (#730)
- Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717)

## 0.17.2 (May 23th, 2023)
Expand Down
72 changes: 51 additions & 21 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,26 @@ def __init__(
self._sent_connection_init = False
self._used_all_stream_ids = False
self._connection_error = False
self._events: typing.Dict[int, h2.events.Event] = {}

# Mapping from stream ID to response stream events.
self._events: typing.Dict[
int,
typing.Union[
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
],
] = {}

# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: typing.Optional[
h2.events.ConnectionTerminated
] = None

self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._connection_error_event: typing.Optional[h2.events.Event] = None

async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand Down Expand Up @@ -111,6 +127,7 @@ async def handle_async_request(self, request: Request) -> Response:
self._events[stream_id] = []
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
self._used_all_stream_ids = True
self._request_count -= 1
raise ConnectionNotAvailable()

try:
Expand Down Expand Up @@ -152,8 +169,8 @@ async def handle_async_request(self, request: Request) -> Response:
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_error_event:
raise RemoteProtocolError(self._connection_error_event)
if self._connection_terminated: # pragma: nocover
raise RemoteProtocolError(self._connection_terminated)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover
Expand Down Expand Up @@ -292,12 +309,14 @@ async def _receive_response_body(
self._h2_state.acknowledge_received_data(amount, stream_id)
await self._write_outgoing_data(request)
yield event.data
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
elif isinstance(event, h2.events.StreamEnded):
break

async def _receive_stream_event(
self, request: Request, stream_id: int
) -> h2.events.Event:
) -> typing.Union[
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
]:
"""
Return the next available event for a given stream ID.

Expand All @@ -306,8 +325,7 @@ async def _receive_stream_event(
while not self._events.get(stream_id):
await self._receive_events(request, stream_id)
event = self._events[stream_id].pop(0)
# The StreamReset event applies to a single stream.
if hasattr(event, "error_code"):
if isinstance(event, h2.events.StreamReset):
raise RemoteProtocolError(event)
return event

Expand All @@ -319,8 +337,12 @@ async def _receive_events(
for a given stream ID.
"""
async with self._read_lock:
if self._connection_error_event is not None: # pragma: nocover
raise RemoteProtocolError(self._connection_error_event)
if self._connection_terminated is not None:
last_stream_id = self._connection_terminated.last_stream_id
if last_stream_id and stream_id > last_stream_id:
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
self._request_count -= 1
raise ConnectionNotAvailable()
raise RemoteProtocolError(self._connection_terminated)

# This conditional is a bit icky. We don't want to block reading if we've
# actually got an event to return for a given stream. We need to do that
Expand All @@ -338,16 +360,21 @@ async def _receive_events(
await self._receive_remote_settings_change(event)
trace.return_value = event

event_stream_id = getattr(event, "stream_id", 0)

# The ConnectionTerminatedEvent applies to the entire connection,
# and should be saved so it can be raised on all streams.
if hasattr(event, "error_code") and event_stream_id == 0:
self._connection_error_event = event
raise RemoteProtocolError(event)

if event_stream_id in self._events:
self._events[event_stream_id].append(event)
elif isinstance(
event,
(
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
),
):
event_stream_id = getattr(event, "stream_id", 0)
if event_stream_id in self._events:
self._events[event_stream_id].append(event)

elif isinstance(event, h2.events.ConnectionTerminated):
self._connection_terminated = event

await self._write_outgoing_data(request)

Expand All @@ -372,7 +399,10 @@ async def _response_closed(self, stream_id: int) -> None:
await self._max_streams_semaphore.release()
del self._events[stream_id]
async with self._state_lock:
if self._state == HTTPConnectionState.ACTIVE and not self._events:
if self._connection_terminated and not self._events:
await self.aclose()

elif self._state == HTTPConnectionState.ACTIVE and not self._events:
self._state = HTTPConnectionState.IDLE
if self._keepalive_expiry is not None:
now = time.monotonic()
Expand Down
72 changes: 51 additions & 21 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,26 @@ def __init__(
self._sent_connection_init = False
self._used_all_stream_ids = False
self._connection_error = False
self._events: typing.Dict[int, h2.events.Event] = {}

# Mapping from stream ID to response stream events.
self._events: typing.Dict[
int,
typing.Union[
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
],
] = {}

# Connection terminated events are stored as state since
# we need to handle them for all streams.
self._connection_terminated: typing.Optional[
h2.events.ConnectionTerminated
] = None

self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._connection_error_event: typing.Optional[h2.events.Event] = None

def handle_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand Down Expand Up @@ -111,6 +127,7 @@ def handle_request(self, request: Request) -> Response:
self._events[stream_id] = []
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
self._used_all_stream_ids = True
self._request_count -= 1
raise ConnectionNotAvailable()

try:
Expand Down Expand Up @@ -152,8 +169,8 @@ def handle_request(self, request: Request) -> Response:
#
# In this case we'll have stored the event, and should raise
# it as a RemoteProtocolError.
if self._connection_error_event:
raise RemoteProtocolError(self._connection_error_event)
if self._connection_terminated: # pragma: nocover
raise RemoteProtocolError(self._connection_terminated)
# If h2 raises a protocol error in some other state then we
# must somehow have made a protocol violation.
raise LocalProtocolError(exc) # pragma: nocover
Expand Down Expand Up @@ -292,12 +309,14 @@ def _receive_response_body(
self._h2_state.acknowledge_received_data(amount, stream_id)
self._write_outgoing_data(request)
yield event.data
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
elif isinstance(event, h2.events.StreamEnded):
break

def _receive_stream_event(
self, request: Request, stream_id: int
) -> h2.events.Event:
) -> typing.Union[
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
]:
"""
Return the next available event for a given stream ID.

Expand All @@ -306,8 +325,7 @@ def _receive_stream_event(
while not self._events.get(stream_id):
self._receive_events(request, stream_id)
event = self._events[stream_id].pop(0)
# The StreamReset event applies to a single stream.
if hasattr(event, "error_code"):
if isinstance(event, h2.events.StreamReset):
raise RemoteProtocolError(event)
return event

Expand All @@ -319,8 +337,12 @@ def _receive_events(
for a given stream ID.
"""
with self._read_lock:
if self._connection_error_event is not None: # pragma: nocover
raise RemoteProtocolError(self._connection_error_event)
if self._connection_terminated is not None:
last_stream_id = self._connection_terminated.last_stream_id
if last_stream_id and stream_id > last_stream_id:
self._request_count -= 1
raise ConnectionNotAvailable()
raise RemoteProtocolError(self._connection_terminated)

# This conditional is a bit icky. We don't want to block reading if we've
# actually got an event to return for a given stream. We need to do that
Expand All @@ -338,16 +360,21 @@ def _receive_events(
self._receive_remote_settings_change(event)
trace.return_value = event

event_stream_id = getattr(event, "stream_id", 0)

# The ConnectionTerminatedEvent applies to the entire connection,
# and should be saved so it can be raised on all streams.
if hasattr(event, "error_code") and event_stream_id == 0:
self._connection_error_event = event
raise RemoteProtocolError(event)

if event_stream_id in self._events:
self._events[event_stream_id].append(event)
elif isinstance(
event,
(
h2.events.ResponseReceived,
h2.events.DataReceived,
h2.events.StreamEnded,
h2.events.StreamReset,
),
):
event_stream_id = getattr(event, "stream_id", 0)
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
if event_stream_id in self._events:
self._events[event_stream_id].append(event)

elif isinstance(event, h2.events.ConnectionTerminated):
self._connection_terminated = event

self._write_outgoing_data(request)

Expand All @@ -372,7 +399,10 @@ def _response_closed(self, stream_id: int) -> None:
self._max_streams_semaphore.release()
del self._events[stream_id]
with self._state_lock:
if self._state == HTTPConnectionState.ACTIVE and not self._events:
if self._connection_terminated and not self._events:
self.close()

elif self._state == HTTPConnectionState.ACTIVE and not self._events:
self._state = HTTPConnectionState.IDLE
if self._keepalive_expiry is not None:
now = time.monotonic()
Expand Down
Loading