From b3c80ee3f7d5d8f0b8bc27afe52e4d46621eaf99 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 31 Jan 2017 13:56:58 -0800 Subject: [PATCH] Accepting back-pressure from slow websocket clients #1367 --- CHANGES.rst | 2 ++ aiohttp/_ws_impl.py | 31 +++++++++++++++++++++++-------- aiohttp/web_ws.py | 1 + docs/web_reference.rst | 10 ++++++++++ tests/test_web_websocket.py | 6 +++++- 5 files changed, 41 insertions(+), 9 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index a9a24f60764..95990abec0f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -17,6 +17,8 @@ CHANGES - Remove `web.Application` dependency from `web.UrlDispatcher` #1510 +- Accepting back-pressure from slow websocket clients #1367 + - Do not pause transport during set_parser stage #1211 - Lingering close doesn't terminate before timeout #1559 diff --git a/aiohttp/_ws_impl.py b/aiohttp/_ws_impl.py index 27ed5693bd3..f695e5f9169 100644 --- a/aiohttp/_ws_impl.py +++ b/aiohttp/_ws_impl.py @@ -65,6 +65,7 @@ class WSMsgType(IntEnum): PACK_LEN3 = Struct('!BBQ').pack PACK_CLOSE_CODE = Struct('!H').pack MSG_SIZE = 2 ** 14 +DEFAULT_LIMIT = 2 ** 16 _WSMessageBase = collections.namedtuple('_WSMessageBase', @@ -299,10 +300,13 @@ def parse_frame(buf, continuation=False): class WebSocketWriter: - def __init__(self, writer, *, use_mask=False, random=random.Random()): + def __init__(self, writer, *, + use_mask=False, limit=DEFAULT_LIMIT, random=random.Random()): self.writer = writer self.use_mask = use_mask self.randrange = random.randrange + self._limit = limit + self._output_size = 0 def _send_frame(self, message, opcode): """Send a frame over the websocket with message as its payload.""" @@ -325,6 +329,7 @@ def _send_frame(self, message, opcode): mask = mask.to_bytes(4, 'big') message = _websocket_mask(mask, bytearray(message)) self.writer.write(header + mask + message) + self._output_size += len(header) + len(mask) + len(message) else: if len(message) > MSG_SIZE: self.writer.write(header) @@ -332,36 +337,45 @@ def _send_frame(self, message, opcode): else: self.writer.write(header + message) + self._output_size += len(header) + len(message) + + if self._output_size > self._limit: + self._output_size = 0 + return self.writer.drain() + + return () + def pong(self, message=b''): """Send pong message.""" if isinstance(message, str): message = message.encode('utf-8') - self._send_frame(message, WSMsgType.PONG) + return self._send_frame(message, WSMsgType.PONG) def ping(self, message=b''): """Send ping message.""" if isinstance(message, str): message = message.encode('utf-8') - self._send_frame(message, WSMsgType.PING) + return self._send_frame(message, WSMsgType.PING) def send(self, message, binary=False): """Send a frame over the websocket with message as its payload.""" if isinstance(message, str): message = message.encode('utf-8') if binary: - self._send_frame(message, WSMsgType.BINARY) + return self._send_frame(message, WSMsgType.BINARY) else: - self._send_frame(message, WSMsgType.TEXT) + return self._send_frame(message, WSMsgType.TEXT) def close(self, code=1000, message=b''): """Close the websocket, sending the specified code and message.""" if isinstance(message, str): message = message.encode('utf-8') - self._send_frame( + return self._send_frame( PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE) -def do_handshake(method, headers, transport, protocols=()): +def do_handshake(method, headers, transport, + protocols=(), write_buffer_size=DEFAULT_LIMIT): """Prepare WebSocket handshake. It return HTTP response code, response headers, websocket parser, @@ -371,6 +385,7 @@ def do_handshake(method, headers, transport, protocols=()): the returned response headers contain the first protocol in this list which the server also knows. + `write_buffer_size` max size of write buffer before `drain()` get called. """ # WebSocket accepts only GET if method.upper() != hdrs.METH_GET: @@ -434,5 +449,5 @@ def do_handshake(method, headers, transport, protocols=()): return (101, response_headers, WebSocketParser, - WebSocketWriter(transport), + WebSocketWriter(transport, limit=write_buffer_size), protocol) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 596a85aa7d9..b81960bfe23 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -187,6 +187,7 @@ def close(self, *, code=1000, message=b''): self._closed = True try: self._writer.close(code, message) + yield from self.drain() except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = 1006 raise diff --git a/docs/web_reference.rst b/docs/web_reference.rst index b55bfd8b4df..efba24ac9e0 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -1109,6 +1109,16 @@ WebSocketResponse .. seealso:: :ref:`WebSockets handling` +WebSocketResponse Send Flow Control +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To enable send flow control you need to treat methods +`ping()`, `pong()`, `send_str()`, `send_bytes()`, `send_json()` as coroutines. +By default write buffer size is set to 64k. + +.. versionadded:: 1.3.0 + + WebSocketReady ^^^^^^^^^^^^^^ diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 53e4844dc5e..33041d7ac6a 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -361,6 +361,8 @@ def test_receive_exc_in_reader(make_request, loop, reader): res = helpers.create_future(loop) res.set_exception(exc) reader.read = make_mocked_coro(res) + ws._resp_impl.transport.drain.return_value = helpers.create_future(loop) + ws._resp_impl.transport.drain.return_value.set_result(True) msg = yield from ws.receive() assert msg.type == WSMsgType.ERROR @@ -444,7 +446,7 @@ def test_concurrent_receive(make_request): @asyncio.coroutine -def test_close_exc(make_request, reader, loop): +def test_close_exc(make_request, reader, loop, mocker): req = make_request('GET', '/') ws = WebSocketResponse() @@ -453,6 +455,8 @@ def test_close_exc(make_request, reader, loop): exc = ValueError() reader.read.return_value = helpers.create_future(loop) reader.read.return_value.set_exception(exc) + ws._resp_impl.transport.drain.return_value = helpers.create_future(loop) + ws._resp_impl.transport.drain.return_value.set_result(True) yield from ws.close() assert ws.closed