Skip to content

Commit

Permalink
Docstrings and minor refactor for HTTP/2. (#720)
Browse files Browse the repository at this point in the history
* Add HTTP/2 data_stream extension

* Expose HTTP2DataStream in 'data_stream' extension.

* Remove data_stream extension

* Remove unrelated change of 'network_stream' extension
  • Loading branch information
tomchristie authored Jun 12, 2023
1 parent 7548dd5 commit eb5957d
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 14 deletions.
51 changes: 44 additions & 7 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ async def handle_async_request(self, request: Request) -> Response:
status=status,
headers=headers,
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={"stream_id": stream_id, "http_version": b"HTTP/2"},
extensions={
"http_version": b"HTTP/2",
"stream_id": stream_id,
},
)
except Exception as exc: # noqa: PIE786
kwargs = {"stream_id": stream_id}
Expand Down Expand Up @@ -190,6 +193,9 @@ async def _send_connection_init(self, request: Request) -> None:
# Sending the request...

async def _send_request_headers(self, request: Request, stream_id: int) -> None:
"""
Send the request headers to a given stream ID.
"""
end_stream = not has_body_headers(request)

# In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
Expand Down Expand Up @@ -218,18 +224,34 @@ async def _send_request_headers(self, request: Request, stream_id: int) -> None:
await self._write_outgoing_data(request)

async def _send_request_body(self, request: Request, stream_id: int) -> None:
"""
Iterate over the request body sending it to a given stream ID.
"""
if not has_body_headers(request):
return

assert isinstance(request.stream, typing.AsyncIterable)
async for data in request.stream:
while data:
max_flow = await self._wait_for_outgoing_flow(request, stream_id)
chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
self._h2_state.send_data(stream_id, chunk)
await self._write_outgoing_data(request)
await self._send_stream_data(request, stream_id, data)
await self._send_end_stream(request, stream_id)

async def _send_stream_data(
self, request: Request, stream_id: int, data: bytes
) -> None:
"""
Send a single chunk of data in one or more data frames.
"""
while data:
max_flow = await self._wait_for_outgoing_flow(request, stream_id)
chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
self._h2_state.send_data(stream_id, chunk)
await self._write_outgoing_data(request)

async def _send_end_stream(self, request: Request, stream_id: int) -> None:
"""
Send an empty data frame on on a given stream ID with the END_STREAM flag set.
"""
self._h2_state.end_stream(stream_id)
await self._write_outgoing_data(request)

Expand All @@ -238,6 +260,9 @@ async def _send_request_body(self, request: Request, stream_id: int) -> None:
async def _receive_response(
self, request: Request, stream_id: int
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
"""
Return the response status code and headers for a given stream ID.
"""
while True:
event = await self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.ResponseReceived):
Expand All @@ -256,6 +281,9 @@ async def _receive_response(
async def _receive_response_body(
self, request: Request, stream_id: int
) -> typing.AsyncIterator[bytes]:
"""
Iterator that returns the bytes of the response body for a given stream ID.
"""
while True:
event = await self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.DataReceived):
Expand All @@ -269,6 +297,11 @@ async def _receive_response_body(
async def _receive_stream_event(
self, request: Request, stream_id: int
) -> h2.events.Event:
"""
Return the next available event for a given stream ID.
Will read more data from the network if required.
"""
while not self._events.get(stream_id):
await self._receive_events(request, stream_id)
event = self._events[stream_id].pop(0)
Expand All @@ -280,6 +313,10 @@ async def _receive_stream_event(
async def _receive_events(
self, request: Request, stream_id: typing.Optional[int] = None
) -> None:
"""
Read some data from the network until we see one or more events
for a given stream ID.
"""
async with self._read_lock:
if self._connection_error_event is not None: # pragma: nocover
raise RemoteProtocolError(self._connection_error_event)
Expand Down
51 changes: 44 additions & 7 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ def handle_request(self, request: Request) -> Response:
status=status,
headers=headers,
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
extensions={"stream_id": stream_id, "http_version": b"HTTP/2"},
extensions={
"http_version": b"HTTP/2",
"stream_id": stream_id,
},
)
except Exception as exc: # noqa: PIE786
kwargs = {"stream_id": stream_id}
Expand Down Expand Up @@ -190,6 +193,9 @@ def _send_connection_init(self, request: Request) -> None:
# Sending the request...

def _send_request_headers(self, request: Request, stream_id: int) -> None:
"""
Send the request headers to a given stream ID.
"""
end_stream = not has_body_headers(request)

# In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
Expand Down Expand Up @@ -218,18 +224,34 @@ def _send_request_headers(self, request: Request, stream_id: int) -> None:
self._write_outgoing_data(request)

def _send_request_body(self, request: Request, stream_id: int) -> None:
"""
Iterate over the request body sending it to a given stream ID.
"""
if not has_body_headers(request):
return

assert isinstance(request.stream, typing.Iterable)
for data in request.stream:
while data:
max_flow = self._wait_for_outgoing_flow(request, stream_id)
chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
self._h2_state.send_data(stream_id, chunk)
self._write_outgoing_data(request)
self._send_stream_data(request, stream_id, data)
self._send_end_stream(request, stream_id)

def _send_stream_data(
self, request: Request, stream_id: int, data: bytes
) -> None:
"""
Send a single chunk of data in one or more data frames.
"""
while data:
max_flow = self._wait_for_outgoing_flow(request, stream_id)
chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
self._h2_state.send_data(stream_id, chunk)
self._write_outgoing_data(request)

def _send_end_stream(self, request: Request, stream_id: int) -> None:
"""
Send an empty data frame on on a given stream ID with the END_STREAM flag set.
"""
self._h2_state.end_stream(stream_id)
self._write_outgoing_data(request)

Expand All @@ -238,6 +260,9 @@ def _send_request_body(self, request: Request, stream_id: int) -> None:
def _receive_response(
self, request: Request, stream_id: int
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
"""
Return the response status code and headers for a given stream ID.
"""
while True:
event = self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.ResponseReceived):
Expand All @@ -256,6 +281,9 @@ def _receive_response(
def _receive_response_body(
self, request: Request, stream_id: int
) -> typing.Iterator[bytes]:
"""
Iterator that returns the bytes of the response body for a given stream ID.
"""
while True:
event = self._receive_stream_event(request, stream_id)
if isinstance(event, h2.events.DataReceived):
Expand All @@ -269,6 +297,11 @@ def _receive_response_body(
def _receive_stream_event(
self, request: Request, stream_id: int
) -> h2.events.Event:
"""
Return the next available event for a given stream ID.
Will read more data from the network if required.
"""
while not self._events.get(stream_id):
self._receive_events(request, stream_id)
event = self._events[stream_id].pop(0)
Expand All @@ -280,6 +313,10 @@ def _receive_stream_event(
def _receive_events(
self, request: Request, stream_id: typing.Optional[int] = None
) -> None:
"""
Read some data from the network until we see one or more events
for a given stream ID.
"""
with self._read_lock:
if self._connection_error_event is not None: # pragma: nocover
raise RemoteProtocolError(self._connection_error_event)
Expand Down

0 comments on commit eb5957d

Please sign in to comment.