From a69ea902c7409739faefd3f4658c4589944a0ff4 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 31 Jan 2017 15:08:59 -0800 Subject: [PATCH] Added receive_timeout timeout for websocket to receive complete message. #1024 #1325 --- CHANGES.rst | 4 +- aiohttp/client.py | 19 +++++++-- aiohttp/client_ws.py | 23 ++++++----- aiohttp/helpers.py | 6 +-- aiohttp/test_utils.py | 8 ++++ aiohttp/web_ws.py | 25 +++++++----- docs/client_reference.rst | 6 ++- docs/web_reference.rst | 41 ++++++++++++------- tests/test_client_ws_functional.py | 50 +++++++++++++++++++++++ tests/test_web_websocket_functional.py | 56 ++++++++++++++++++++++++++ 10 files changed, 197 insertions(+), 41 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 95990abec0f..9487ee69f0f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,7 +4,7 @@ CHANGES 1.3.0 (XXXX-XX-XX) ------------------ -- separate read + connect + request timeouts # 1523 +- Separate read + connect + request timeouts # 1523 - Fix polls demo run application #1487 @@ -15,6 +15,8 @@ CHANGES - Do not use readline when reading the content of a part in the multipart reader #1535 +- Added `receive_timeout` timeout for websocket to receive complete message. #1024 #1325 + - Remove `web.Application` dependency from `web.UrlDispatcher` #1510 - Accepting back-pressure from slow websocket clients #1367 diff --git a/aiohttp/client.py b/aiohttp/client.py index d60929897dc..99bb2662f11 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -19,7 +19,7 @@ from .client_ws import ClientWebSocketResponse from .cookiejar import CookieJar from .errors import WSServerHandshakeError -from .helpers import Timeout +from .helpers import Timeout, TimeService __all__ = ('ClientSession', 'request', 'get', 'options', 'head', 'delete', 'post', 'put', 'patch', 'ws_connect') @@ -55,7 +55,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None, response_class=ClientResponse, ws_response_class=ClientWebSocketResponse, version=aiohttp.HttpVersion11, - cookie_jar=None, read_timeout=None): + cookie_jar=None, read_timeout=None, time_service=None): if connector is None: connector = aiohttp.TCPConnector(loop=loop) @@ -107,6 +107,10 @@ def __init__(self, *, connector=None, loop=None, cookies=None, self._request_class = request_class self._response_class = response_class self._ws_response_class = ws_response_class + self._time_service = ( + time_service + if time_service is not None + else TimeService(self._loop)) def __del__(self, _warnings=warnings): if not self.closed: @@ -120,6 +124,10 @@ def __del__(self, _warnings=warnings): context['source_traceback'] = self._source_traceback self._loop.call_exception_handler(context) + @property + def time_service(self): + return self._time_service + def request(self, method, url, **kwargs): """Perform HTTP request.""" return _RequestContextManager(self._request(method, url, **kwargs)) @@ -278,6 +286,7 @@ def _request(self, method, url, *, def ws_connect(self, url, *, protocols=(), timeout=10.0, + receive_timeout=None, autoclose=True, autoping=True, auth=None, @@ -290,6 +299,7 @@ def ws_connect(self, url, *, self._ws_connect(url, protocols=protocols, timeout=timeout, + receive_timeout=receive_timeout, autoclose=autoclose, autoping=autoping, auth=auth, @@ -302,6 +312,7 @@ def ws_connect(self, url, *, def _ws_connect(self, url, *, protocols=(), timeout=10.0, + receive_timeout=None, autoclose=True, autoping=True, auth=None, @@ -394,7 +405,9 @@ def _ws_connect(self, url, *, timeout, autoclose, autoping, - self._loop) + self._loop, + time_service=self.time_service, + receive_timeout=receive_timeout) def _prepare_headers(self, headers): """ Add default headers and transform it to CIMultiDict diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 6b09e0300eb..a6abf82fd4e 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -13,17 +13,20 @@ class ClientWebSocketResponse: def __init__(self, reader, writer, protocol, - response, timeout, autoclose, autoping, loop): + response, timeout, autoclose, autoping, loop, *, + time_service=None, receive_timeout=None): self._response = response self._conn = response.connection self._writer = writer self._reader = reader self._protocol = protocol + self._time_service = time_service self._closed = False self._closing = False self._close_code = None self._timeout = timeout + self._receive_timeout = receive_timeout self._autoclose = autoclose self._autoping = autoping self._loop = loop @@ -115,7 +118,7 @@ def close(self, *, code=1000, message=b''): return False @asyncio.coroutine - def receive(self): + def receive(self, timeout=None): if self._waiting: raise RuntimeError('Concurrent call to receive() is not allowed') @@ -126,7 +129,9 @@ def receive(self): return CLOSED_MESSAGE try: - msg = yield from self._reader.read() + with self._time_service.timeout( + timeout or self._receive_timeout): + msg = yield from self._reader.read() except (asyncio.CancelledError, asyncio.TimeoutError): raise except WebSocketError as exc: @@ -156,8 +161,8 @@ def receive(self): self._waiting = False @asyncio.coroutine - def receive_str(self): - msg = yield from self.receive() + def receive_str(self, *, timeout=None): + msg = yield from self.receive(timeout) if msg.type != WSMsgType.TEXT: raise TypeError( "Received message {}:{!r} is not str".format(msg.type, @@ -165,8 +170,8 @@ def receive_str(self): return msg.data @asyncio.coroutine - def receive_bytes(self): - msg = yield from self.receive() + def receive_bytes(self, *, timeout=None): + msg = yield from self.receive(timeout) if msg.type != WSMsgType.BINARY: raise TypeError( "Received message {}:{!r} is not bytes".format(msg.type, @@ -174,8 +179,8 @@ def receive_bytes(self): return msg.data @asyncio.coroutine - def receive_json(self, *, loads=json.loads): - data = yield from self.receive_str() + def receive_json(self, *, loads=json.loads, timeout=None): + data = yield from self.receive_str(timeout=timeout) return loads(data) if PY_35: diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 8e012282471..cebc86f7a71 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -691,8 +691,6 @@ class LowresTimeout: """ Low resolution timeout context manager """ def __init__(self, timeout, time_service, loop): - assert timeout is not None - self._loop = loop self._timeout = timeout self._time_service = time_service @@ -712,14 +710,14 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): + self._task = None + if exc_type is asyncio.CancelledError and self._cancelled: self._cancel_handler = None - self._task = None raise asyncio.TimeoutError from None if self._timeout is not None: self._cancel_handler.cancel() self._cancel_handler = None - self._task = None def _cancel_task(self): self._cancelled = self._task.cancel() diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 182e2af5fe7..1cc60097fec 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -8,6 +8,7 @@ import sys import unittest from abc import ABC, abstractmethod +from contextlib import contextmanager from unittest import mock from multidict import CIMultiDict @@ -513,6 +514,13 @@ def make_mocked_request(method, path, headers=None, *, time_service.time.return_value = 12345 time_service.strtime.return_value = "Tue, 15 Nov 1994 08:12:31 GMT" + @contextmanager + def timeout(*args, **kw): + yield + + time_service.timeout = mock.Mock() + time_service.timeout.side_effect = timeout + task = mock.Mock() req = Request(message, payload, diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index b81960bfe23..bff1fa72fc4 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -32,7 +32,8 @@ def __bool__(self): class WebSocketResponse(StreamResponse): def __init__(self, *, - timeout=10.0, autoclose=True, autoping=True, protocols=()): + timeout=10.0, receive_timeout=None, + autoclose=True, autoping=True, protocols=()): super().__init__(status=101) self._protocols = protocols self._protocol = None @@ -46,8 +47,10 @@ def __init__(self, *, self._waiting = False self._exception = None self._timeout = timeout + self._receive_timeout = receive_timeout self._autoclose = autoclose self._autoping = autoping + self._time_service = None @asyncio.coroutine def prepare(self, request): @@ -75,6 +78,8 @@ def _pre_start(self, request): else: # pragma: no cover raise HTTPInternalServerError() from err + self._time_service = request.time_service + if self.status != status: self.set_status(status) for k, v in headers: @@ -224,7 +229,7 @@ def close(self, *, code=1000, message=b''): return False @asyncio.coroutine - def receive(self): + def receive(self, timeout=None): if self._reader is None: raise RuntimeError('Call .prepare() first') if self._waiting: @@ -240,7 +245,9 @@ def receive(self): return CLOSED_MESSAGE try: - msg = yield from self._reader.read() + with self._time_service.timeout( + timeout or self._receive_timeout): + msg = yield from self._reader.read() except (asyncio.CancelledError, asyncio.TimeoutError): raise except WebSocketError as exc: @@ -281,8 +288,8 @@ def receive_msg(self): return (yield from self.receive()) @asyncio.coroutine - def receive_str(self): - msg = yield from self.receive() + def receive_str(self, *, timeout=None): + msg = yield from self.receive(timeout) if msg.type != WSMsgType.TEXT: raise TypeError( "Received message {}:{!r} is not str".format(msg.type, @@ -290,8 +297,8 @@ def receive_str(self): return msg.data @asyncio.coroutine - def receive_bytes(self): - msg = yield from self.receive() + def receive_bytes(self, *, timeout=None): + msg = yield from self.receive(timeout) if msg.type != WSMsgType.BINARY: raise TypeError( "Received message {}:{!r} is not bytes".format(msg.type, @@ -299,8 +306,8 @@ def receive_bytes(self): return msg.data @asyncio.coroutine - def receive_json(self, *, loads=json.loads): - data = yield from self.receive_str() + def receive_json(self, *, loads=json.loads, timeout=None): + data = yield from self.receive_str(timeout=timeout) return loads(data) def write(self, data): diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 42cbe5c0644..3598b697b96 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -386,6 +386,7 @@ The client session supports the context manager protocol for self closing. URLs may be either :class:`str` or :class:`~yarl.URL` .. comethod:: ws_connect(url, *, protocols=(), timeout=10.0,\ + receive_timeout=None,\ auth=None,\ autoclose=True,\ autoping=True,\ @@ -401,8 +402,11 @@ The client session supports the context manager protocol for self closing. :param tuple protocols: Websocket protocols - :param float timeout: Timeout for websocket read. 10 seconds by default + :param float timeout: Timeout for websocket to close. 10 seconds by default + :param float receive_timeout: Timeout for websocket to receive complete message. + None(unlimited) seconds by default + :param aiohttp.BasicAuth auth: an object that represents HTTP Basic Authorization (optional) diff --git a/docs/web_reference.rst b/docs/web_reference.rst index efba24ac9e0..2572d2e54d0 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -857,6 +857,12 @@ WebSocketResponse communicate with websocket client by :meth:`send_str`, :meth:`receive` and others. + .. versionadded:: 1.3.0 + + To enable back-pressure from slow websocket clients treat methods + `ping()`, `pong()`, `send_str()`, `send_bytes()`, `send_json()` as coroutines. + By default write buffer size is set to 64k. + :param bool autoping: Automatically send :const:`~aiohttp.WSMsgType.PONG` on :const:`~aiohttp.WSMsgType.PING` @@ -868,6 +874,11 @@ WebSocketResponse requests, you need to do this explicitly using :meth:`ping` method. + .. versionadded:: 1.3.0 + + :param float receive_timeout: Timeout value for `receive` operations. + Default value is None (no timeout for receive operation) + .. versionadded:: 0.19 The class supports ``async for`` statement for iterating over @@ -1033,7 +1044,7 @@ WebSocketResponse :raise RuntimeError: if connection is not started or closing - .. coroutinemethod:: receive() + .. coroutinemethod:: receive(timeout=None) A :ref:`coroutine` that waits upcoming *data* message from peer and returns it. @@ -1050,13 +1061,16 @@ WebSocketResponse Can only be called by the request handling task. + :param timeout: timeout for `receive` operation. + timeout value overrides response`s receive_timeout attribute. + :return: :class:`~aiohttp.WSMessage` :raise RuntimeError: if connection is not started :raise: :exc:`~aiohttp.errors.WSClientDisconnectedError` on closing. - .. coroutinemethod:: receive_str() + .. coroutinemethod:: receive_str(*, timeout=None) A :ref:`coroutine` that calls :meth:`receive` but also asserts the message type is @@ -1066,11 +1080,14 @@ WebSocketResponse Can only be called by the request handling task. + :param timeout: timeout for `receive` operation. + timeout value overrides response`s receive_timeout attribute. + :return str: peer's message content. :raise TypeError: if message is :const:`~aiohttp.WSMsgType.BINARY`. - .. coroutinemethod:: receive_bytes() + .. coroutinemethod:: receive_bytes(*, timeout=None) A :ref:`coroutine` that calls :meth:`receive` but also asserts the message type is @@ -1080,11 +1097,14 @@ WebSocketResponse Can only be called by the request handling task. + :param timeout: timeout for `receive` operation. + timeout value overrides response`s receive_timeout attribute. + :return bytes: peer's message content. :raise TypeError: if message is :const:`~aiohttp.WSMsgType.TEXT`. - .. coroutinemethod:: receive_json(*, loads=json.loads) + .. coroutinemethod:: receive_json(*, loads=json.loads, timeout=None) A :ref:`coroutine` that calls :meth:`receive_str` and loads the JSON string to a Python dict. @@ -1098,6 +1118,9 @@ WebSocketResponse with parsed JSON (:func:`json.loads` by default). + :param timeout: timeout for `receive` operation. + timeout value overrides response`s receive_timeout attribute. + :return dict: loaded JSON content :raise TypeError: if message is :const:`~aiohttp.WSMsgType.BINARY`. @@ -1107,16 +1130,6 @@ WebSocketResponse .. seealso:: :ref:`WebSockets handling` - - -WebSocketResponse Send Flow Control -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To enable send flow control you need to treat methods -`ping()`, `pong()`, `send_str()`, `send_bytes()`, `send_json()` as coroutines. -By default write buffer size is set to 64k. - -.. versionadded:: 1.3.0 WebSocketReady diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 16cd4efcde7..936bc8ccdfd 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -469,3 +469,53 @@ def handler(request): yield from resp.receive() yield from resp.close() + + +@asyncio.coroutine +def test_receive_timeout(loop, test_client): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + yield from ws.receive_str() + yield from asyncio.sleep(1.1, loop=request.app.loop) + yield from ws.close() + return ws + + app = web.Application(loop=loop) + app.router.add_route('GET', '/', handler) + + client = yield from test_client(app) + resp = yield from client.ws_connect('/', receive_timeout=0.1) + resp.send_str('ask') + + with pytest.raises(asyncio.TimeoutError): + yield from resp.receive(0.1) + + yield from resp.close() + + +@asyncio.coroutine +def test_custom_receive_timeout(loop, test_client): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + yield from ws.receive_str() + yield from asyncio.sleep(1.1, loop=request.app.loop) + yield from ws.close() + return ws + + app = web.Application(loop=loop) + app.router.add_route('GET', '/', handler) + + client = yield from test_client(app) + resp = yield from client.ws_connect('/') + resp.send_str('ask') + + with pytest.raises(asyncio.TimeoutError): + yield from resp.receive(0.1) + + yield from resp.close() diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index f0c01b8f461..1e3e0b9c824 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -565,3 +565,59 @@ def handler(request): ws = yield from client.ws_connect('/') ws.send_bytes(b'data') yield from ws.close() + + +@asyncio.coroutine +def test_receive_timeout(loop, test_client): + raised = False + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(receive_timeout=1.0) + yield from ws.prepare(request) + + try: + yield from ws.receive() + except asyncio.TimeoutError: + nonlocal raised + raised = True + + yield from ws.close() + return ws + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + ws = yield from client.ws_connect('/') + yield from asyncio.sleep(1.06, loop=loop) + yield from ws.close() + assert raised + + +@asyncio.coroutine +def test_custom_receive_timeout(loop, test_client): + raised = False + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(receive_timeout=None) + yield from ws.prepare(request) + + try: + yield from ws.receive(1.0) + except asyncio.TimeoutError: + nonlocal raised + raised = True + + yield from ws.close() + return ws + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + ws = yield from client.ws_connect('/') + yield from asyncio.sleep(1.06, loop=loop) + yield from ws.close() + assert raised