Skip to content

Commit

Permalink
[WIP] Extending the capabilities of the HTTP/2.0 server
Browse files Browse the repository at this point in the history
* Made multithreading more robust, and made it so each stream gets its
own thread

* Requests start getting parsed immediately, and do not wait for the
entire request to get there. This helps with 404 detection early on and
other use cases.

* Created H2Request object
  • Loading branch information
David Heiberg committed Jul 26, 2018
1 parent c54b451 commit f2654f3
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 33 deletions.
10 changes: 8 additions & 2 deletions tools/wptserve/wptserve/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,14 @@ def __init__(self, request_handler):

self.raw_input = InputFile(request_handler.rfile,
int(self.headers.get("Content-Length", 0)))

self._body = None

self._GET = None
self._POST = None
self._cookies = None
self._auth = None

self.h2_stream_id = request_handler.h2_stream_id if hasattr(request_handler, 'h2_stream_id') else None

self.server = Server(self)

def __repr__(self):
Expand Down Expand Up @@ -349,6 +348,13 @@ def auth(self):
return self._auth


class H2Request(Request):
def __init__(self, request_handler):
self.h2_stream_id = request_handler.h2_stream_id
self.frames = []
super(H2Request, self).__init__(request_handler)


class RequestHeaders(dict):
"""Dictionary-like API for accessing request headers."""
def __init__(self, items):
Expand Down
24 changes: 18 additions & 6 deletions tools/wptserve/wptserve/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def __init__(self, handler, response):
self.request = response.request
self.logger = response.logger

def write_headers(self, headers, status_code, status_message=None):
def write_headers(self, headers, status_code, status_message=None, stream_id=None, last=False):
formatted_headers = []
secondary_headers = [] # Non ':' prefixed headers are to be added afterwards

Expand All @@ -403,13 +403,14 @@ def write_headers(self, headers, status_code, status_message=None):

with self.h2conn as connection:
connection.send_headers(
stream_id=self.request.h2_stream_id,
stream_id=self.request.h2_stream_id if stream_id is None else stream_id,
headers=formatted_headers,
end_stream=last or self.request.method == "HEAD"
)

self.write(connection)

def write_content(self, item, last=False):
def write_content(self, item, last=False, stream_id=None):
if isinstance(item, (text_type, binary_type)):
data = BytesIO(self.encode(item))
else:
Expand All @@ -427,18 +428,29 @@ def write_content(self, item, last=False):
data_len -= payload_size
payload_size = self.get_max_payload_size()

self.write_content_frame(data.read(), last)
self.write_content_frame(data.read(), last, stream_id)

def write_content_frame(self, data, last):
def write_content_frame(self, data, last, stream_id=None):
with self.h2conn as connection:
connection.send_data(
stream_id=self.request.h2_stream_id,
stream_id=self.request.h2_stream_id if stream_id is None else stream_id,
data=data,
end_stream=last,
)
self.write(connection)
self.content_written = last

def write_push(self, promise_headers, response_headers, status, data=None, stream_id=None):
with self.h2conn as connection:
connection.push_stream(self.request.h2_stream_id, stream_id, promise_headers)
self.write(connection)

has_data = data is not None
self.write_headers(response_headers, status, stream_id=stream_id, last=not has_data)

if has_data:
self.write_content(data, last=True, stream_id=stream_id)

def get_max_payload_size(self):
with self.h2conn as connection:
return min(connection.remote_settings.max_frame_size, connection.local_flow_control_window(self.request.h2_stream_id)) - 9
Expand Down
125 changes: 100 additions & 25 deletions tools/wptserve/wptserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
import uuid
from collections import OrderedDict

if sys.version[0] == '2':
from Queue import Queue
else:
from queue import Queue
from copy import copy

from h2.config import H2Configuration
from h2.connection import H2Connection
from h2.events import RequestReceived, ConnectionTerminated
from h2.events import RequestReceived, ConnectionTerminated, DataReceived, StreamEnded, StreamReset

from six.moves.urllib.parse import urlsplit, urlunsplit

from . import routes as default_routes
from .config import ConfigBuilder
from .logger import get_logger
from .request import Server, Request
from .request import Server, Request, H2Request
from .response import Response, H2Response
from .router import Router
from .utils import HTTPException
Expand Down Expand Up @@ -220,11 +226,12 @@ def __init__(self, *args, **kwargs):
self.logger = get_logger()
BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args, **kwargs)

def finish_handling(self, request_line_is_valid, response_cls):
self.server.rewriter.rewrite(self)
def finish_handling(self, request_line_is_valid, response_cls, request=None, response=None):
if request is None or response is None:
self.server.rewriter.rewrite(self)

request = Request(self)
response = response_cls(self, request)
request = Request(self)
response = response_cls(self, request)

if request.method == "CONNECT":
self.handle_connect(response)
Expand Down Expand Up @@ -254,6 +261,7 @@ def finish_handling(self, request_line_is_valid, response_cls):
time.sleep(latency / 1000.)

if handler is None:
self.logger.debug("No Handler found!")
response.set_error(404)
else:
try:
Expand Down Expand Up @@ -322,9 +330,9 @@ def handle_one_request(self):
self.close_connection = False

# Generate a UUID to make it easier to distinguish different H2 connection debug messages
uid = uuid.uuid4()
self.uid = str(uuid.uuid4())[:8]

self.logger.debug('(%s) Initiating h2 Connection' % uid)
self.logger.debug('(%s) Initiating h2 Connection' % self.uid)

with self.conn as connection:
connection.initiate_connection()
Expand All @@ -337,48 +345,115 @@ def handle_one_request(self):
# TODO Need to do some major work on multithreading. Current idea is to have a thread per stream
# so that processing of the request can start from the first frame.

self.stream_queues = {}

while not self.close_connection:
try:
# This size may need to be made variable based on remote settings?
data = self.request.recv(65535)

with self.conn as connection:
events = connection.receive_data(data)

self.logger.debug('(%s) Events: ' % (uid) + str(events))

for event in events:
if isinstance(event, RequestReceived):
self.logger.debug('(%s) Parsing RequestReceived' % (uid))
self._h2_parse_request(event)
t = threading.Thread(target=BaseWebTestRequestHandler.finish_handling, args=(self, True, H2Response))
self.request_threads.append(t)
t.start()
if isinstance(event, ConnectionTerminated):
self.logger.debug('(%s) Connection terminated by remote peer ' % (uid))
frames = connection.receive_data(data)

self.logger.debug('(%s) Frames Received: ' % (self.uid) + str(frames))

for frame in frames:
if isinstance(frame, (RequestReceived, DataReceived, StreamEnded, StreamReset)):
if frame.stream_id not in self.stream_queues:
self.stream_queues[frame.stream_id] = Queue()
self.start_stream_thread(frame)
self.stream_queues[frame.stream_id].put(frame)
elif isinstance(frame, ConnectionTerminated):
self.logger.debug('(%s) Connection terminated by remote peer ' % (self.uid))
self.close_connection = True

# Flood all the streams with connection terminated, this will cause them to stop
for stream_id, queue in self.stream_queues.items():
queue.put(frame)

except (socket.timeout, socket.error) as e:
self.logger.debug('(%s) ERROR - Closing Connection - \n%s' % (uid, str(e)))
self.logger.debug('(%s) ERROR - Closing Connection - \n%s' % (self.uid, str(e)))
self.close_connection = True
for t in self.request_threads:
t.join()

def _h2_parse_request(self, event):
self.headers = H2Headers(event.headers)
def start_stream_thread(self, frame):
t = threading.Thread(
target=Http2WebTestRequestHandler._stream_thread,
args=(self, frame.stream_id)
)
self.request_threads.append(t)
t.start()

def _stream_thread(self, stream_id):
"""
This thread processes frames for a specific stream. It waits for frames to be placed
in the queue, and processes them. When it receives a request frame, it will start processing
immediately, even if there are data frames to follow. One of the reasons for this is that it
can detect invalid requests before needing to read the rest of the frames.
"""

# The file-like pipe object that will be used to share data to request object if data is received
wfile = None
request = None
while not self.close_connection:
# Wait for next frame, blocking
frame = self.stream_queues[stream_id].get(True, None)

self.logger.debug('(%s - %s) %s' % (self.uid, stream_id, str(frame)))

if isinstance(frame, RequestReceived):
# Create a shallow copy of this instance, to avoid issues with multiple threads editing the same object
self_copy = copy(self)
self_copy._h2_parse_request(frame)

# If its a POST request, we expect data frames and open the pipe to the request object
if self_copy.command == 'POST':
rfile, wfile = os.pipe()
rfile, wfile = os.fdopen(rfile, 'rb'), os.fdopen(wfile, 'wb')
self_copy.rfile = rfile

self_copy.server.rewriter.rewrite(self_copy)
request = H2Request(self_copy)
request.frames.append(frame)
response = H2Response(self_copy, request)

# Begin processing the response in another thread,
# and continue listening for more data here in this one
t = threading.Thread(
target=BaseWebTestRequestHandler.finish_handling,
args=(self, True, H2Response,),
kwargs={'request':request, 'response':response}
)
t.start()
elif isinstance(frame, DataReceived):
request.frames.append(frame)
wfile.write(frame.data)
if frame.stream_ended:
wfile.close()
elif isinstance(frame, (StreamReset, ConnectionTerminated)):
assert self.stream_queues[stream_id].empty()
del self.stream_queues[stream_id]
self.logger.debug('(%s - %s) Stream Reset, Thread Closing' % (self.uid, stream_id))
break

def _h2_parse_request(self, frame):
self.headers = H2Headers(frame.headers)
self.command = self.headers['method']
self.path = self.headers['path']
self.h2_stream_id = event.stream_id
self.h2_stream_id = frame.stream_id

# TODO Need to figure out what to do with this thing as it is no longer used
# For now I can just leave it be as it does not affect anything
self.raw_requestline = ''


class H2ConnectionGuard(object):
"""H2Connection objects are not threadsafe, so this keeps thread safety"""
lock = threading.Lock()

def __init__(self, obj):
assert isinstance(obj, H2Connection)
self.obj = obj

def __enter__(self):
Expand Down

0 comments on commit f2654f3

Please sign in to comment.