Client timeouts #2972

Merged: 18 commits, May 11, 2018
1 change: 1 addition & 0 deletions CHANGES/2768.feature
@@ -0,0 +1 @@
Implement ``ClientTimeout`` class and support socket read timeout.
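
For context, a minimal sketch of how the feature described by this changelog entry might be used once this branch is installed. The URL and values are placeholders; only ClientTimeout, the timeout= argument, and sock_read come from this PR.

import asyncio

import aiohttp

async def main():
    # sock_read bounds each read from the socket; total still caps the whole request.
    timeout = aiohttp.ClientTimeout(total=5 * 60, sock_read=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get('https://example.com/') as resp:  # placeholder URL
            print(resp.status, len(await resp.read()))

asyncio.get_event_loop().run_until_complete(main())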
84 changes: 69 additions & 15 deletions aiohttp/client.py
@@ -10,6 +10,7 @@
import warnings
from collections.abc import Coroutine

import attr
from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
from yarl import URL

@@ -38,11 +39,33 @@
__all__ = (client_exceptions.__all__ + # noqa
client_reqrep.__all__ + # noqa
connector_mod.__all__ + # noqa
('ClientSession', 'ClientWebSocketResponse', 'request'))
('ClientSession', 'ClientTimeout',
'ClientWebSocketResponse', 'request'))


# 5 Minute default read and connect timeout
DEFAULT_TIMEOUT = 5 * 60
@attr.s(frozen=True, slots=True)
class ClientTimeout:
total = attr.ib(type=float, default=None)
connect = attr.ib(type=float, default=None)
sock_read = attr.ib(type=float, default=None)
sock_connect = attr.ib(type=float, default=None)

# pool_queue_timeout = attr.ib(type=float, default=None)
# dns_resolution_timeout = attr.ib(type=float, default=None)
# socket_connect_timeout = attr.ib(type=float, default=None)
# connection_acquiring_timeout = attr.ib(type=float, default=None)
# new_connection_timeout = attr.ib(type=float, default=None)
# http_header_timeout = attr.ib(type=float, default=None)
# response_body_timeout = attr.ib(type=float, default=None)

# to create a timeout specific for a single request, either
# - create a completely new one to overwrite the default
# - or use http://www.attrs.org/en/stable/api.html#attr.evolve
# to overwrite the defaults


# 5 Minute default read timeout
DEFAULT_TIMEOUT = ClientTimeout(total=5*60)
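
The comment above points at two ways to derive a per-request timeout from the defaults. A small illustrative sketch, not part of the diff, assuming the attrs package and this branch of aiohttp; the variable names are made up.

import attr

from aiohttp import ClientTimeout

default = ClientTimeout(total=5 * 60)

# Option 1: build a completely new timeout that replaces the default.
download = ClientTimeout(total=None, sock_read=60)

# Option 2: keep the default values and overwrite selected fields.
probe = attr.evolve(default, connect=5, sock_connect=5)

attr.evolve fits here because ClientTimeout is a frozen attrs class, so instances are never mutated in place.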


class ClientSession:
@@ -52,8 +75,8 @@ class ClientSession:
'_source_traceback', '_connector',
'requote_redirect_url', '_loop', '_cookie_jar',
'_connector_owner', '_default_auth',
'_version', '_json_serialize', '_read_timeout',
'_conn_timeout', '_raise_for_status', '_auto_decompress',
'_version', '_json_serialize',
'_timeout', '_raise_for_status', '_auto_decompress',
'_trust_env', '_default_headers', '_skip_auto_headers',
'_request_class', '_response_class',
'_ws_response_class', '_trace_configs'])
@@ -71,6 +94,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
version=http.HttpVersion11,
cookie_jar=None, connector_owner=True, raise_for_status=False,
read_timeout=sentinel, conn_timeout=None,
timeout=sentinel,
auto_decompress=True, trust_env=False,
trace_configs=None):

@@ -117,9 +141,26 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
self._default_auth = auth
self._version = version
self._json_serialize = json_serialize
self._read_timeout = (read_timeout if read_timeout is not sentinel
else DEFAULT_TIMEOUT)
self._conn_timeout = conn_timeout
if timeout is not sentinel:
self._timeout = timeout
else:
self._timeout = DEFAULT_TIMEOUT
if read_timeout is not sentinel:
if timeout is not sentinel:
raise ValueError("read_timeout and timeout parameters "
"conflict, please setup "
"timeout.read")
else:
self._timeout = attr.evolve(self._timeout,
total=read_timeout)
if conn_timeout is not None:
if timeout is not sentinel:
raise ValueError("conn_timeout and timeout parameters "
"conflict, please setup "
"timeout.connect")
else:
self._timeout = attr.evolve(self._timeout,
connect=conn_timeout)
self._raise_for_status = raise_for_status
self._auto_decompress = auto_decompress
self._trust_env = trust_env
@@ -244,11 +285,14 @@ async def _request(self, method, url, *,
except ValueError:
raise InvalidURL(proxy)

if timeout is sentinel:
timeout = self._timeout
else:
if not isinstance(timeout, ClientTimeout):
timeout = ClientTimeout(total=timeout)
# timeout is cumulative for all request operations
# (request, redirects, responses, data consuming)
tm = TimeoutHandle(
self._loop,
timeout if timeout is not sentinel else self._read_timeout)
tm = TimeoutHandle(self._loop, timeout.total)
handle = tm.start()

traces = [
@@ -309,15 +353,17 @@
expect100=expect100, loop=self._loop,
response_class=self._response_class,
proxy=proxy, proxy_auth=proxy_auth, timer=timer,
session=self, auto_decompress=self._auto_decompress,
session=self,
ssl=ssl, proxy_headers=proxy_headers, traces=traces)

# connection timeout
try:
with CeilTimeout(self._conn_timeout, loop=self._loop):
with CeilTimeout(self._timeout.connect,
loop=self._loop):

Contributor: I think this still means that the connect timeout includes the time waiting for a connector from the pool, no?

Member (author): Nice catch! Let's keep ClientTimeout.connect as is for backward compatibility, but implement .sock_connect for the actual connection establishment time.
conn = await self._connector.connect(
req,
traces=traces
traces=traces,
timeout=timeout
)
except asyncio.TimeoutError as exc:
raise ServerTimeoutError(
@@ -326,11 +372,19 @@

tcp_nodelay(conn.transport, True)
tcp_cork(conn.transport, False)

conn.protocol.set_response_params(
timer=timer,
skip_payload=method.upper() == 'HEAD',
read_until_eof=read_until_eof,
auto_decompress=self._auto_decompress,
read_timeout=timeout.sock_read)

try:
try:
resp = await req.send(conn)
try:
await resp.start(conn, read_until_eof)
await resp.start(conn)
except BaseException:
resp.close()
raise
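Taken together, the client.py changes give the following behaviour at the session and request level. A hedged sketch, not part of the diff; the URL is a placeholder, and the comment about sock_connect reflects the review discussion above rather than code shown in this excerpt.

import asyncio

import aiohttp

async def fetch(url):
    # Legacy read_timeout/conn_timeout are folded into ClientTimeout
    # (read_timeout -> total, conn_timeout -> connect); passing either of them
    # together with timeout= raises ValueError.
    session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=60, connect=10))
    try:
        # A bare number passed per request is wrapped as ClientTimeout(total=...).
        async with session.get(url, timeout=30) as resp:
            await resp.read()

        # A ClientTimeout instance is used as given; sock_read bounds each
        # socket read, sock_connect is meant for connection establishment only.
        per_request = aiohttp.ClientTimeout(sock_connect=5, sock_read=10)
        async with session.get(url, timeout=per_request) as resp:
            await resp.read()
    finally:
        await session.close()

asyncio.get_event_loop().run_until_complete(fetch('https://example.com/'))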
55 changes: 49 additions & 6 deletions aiohttp/client_proto.py
@@ -2,7 +2,7 @@

from .base_protocol import BaseProtocol
from .client_exceptions import (ClientOSError, ClientPayloadError,
ServerDisconnectedError)
ServerDisconnectedError, ServerTimeoutError)
from .http import HttpResponseParser
from .streams import EMPTY_PAYLOAD, DataQueue

@@ -16,7 +16,6 @@ def __init__(self, *, loop=None):

self._should_close = False

self._message = None
self._payload = None
self._skip_payload = False
self._payload_parser = None
@@ -28,6 +27,9 @@ def __init__(self, *, loop=None):
self._upgraded = False
self._parser = None

self._read_timeout = None
self._read_timeout_handle = None

@property
def upgraded(self):
return self._upgraded
@@ -55,6 +57,8 @@ def is_connected(self):
return self.transport is not None

def connection_lost(self, exc):
self._drop_timeout()

if self._payload_parser is not None:
with suppress(Exception):
self._payload_parser.feed_eof()
@@ -78,15 +82,15 @@ def connection_lost(self, exc):

self._should_close = True
self._parser = None
self._message = None
self._payload = None
self._payload_parser = None
self._reading_paused = False

super().connection_lost(exc)

def eof_received(self):
pass
# should call parser.feed_eof() most likely
self._drop_timeout()

def pause_reading(self):
if not self._reading_paused:
@@ -95,6 +99,7 @@ def pause_reading(self):
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = True
self._drop_timeout()

def resume_reading(self):
if self._reading_paused:
@@ -103,24 +108,33 @@ def resume_reading(self):
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = False
self._reschedule_timeout()

def set_exception(self, exc):
self._should_close = True
self._drop_timeout()
super().set_exception(exc)

def set_parser(self, parser, payload):
self._payload = payload
self._payload_parser = parser

self._drop_timeout()

if self._tail:
data, self._tail = self._tail, b''
self.data_received(data)

def set_response_params(self, *, timer=None,
skip_payload=False,
read_until_eof=False,
auto_decompress=True):
auto_decompress=True,
read_timeout=None):
self._skip_payload = skip_payload

self._read_timeout = read_timeout
self._reschedule_timeout()

self._parser = HttpResponseParser(
self, self._loop, timer=timer,
payload_exception=ClientPayloadError,
@@ -131,6 +145,26 @@ def set_response_params(self, *, timer=None,
data, self._tail = self._tail, b''
self.data_received(data)

def _drop_timeout(self):
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
self._read_timeout_handle = None

def _reschedule_timeout(self):
timeout = self._read_timeout
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()

if timeout:
self._read_timeout_handle = self._loop.call_later(
timeout, self._on_read_timeout)
else:
self._read_timeout_handle = None

def _on_read_timeout(self):
self.set_exception(
ServerTimeoutError("Timeout on reading data from socket"))

def data_received(self, data):
if not data:
return
@@ -161,17 +195,26 @@ def data_received(self, data):

self._upgraded = upgraded

payload = None
for message, payload in messages:
if message.should_close:
self._should_close = True

self._message = message
self._payload = payload

if self._skip_payload or message.code in (204, 304):
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
if payload is not None:
# new message(s) was processed
# register timeout handler unsubscribing
# either on end-of-stream or immediately for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()

if tail:
if upgraded:
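The protocol changes above implement the socket read timeout with loop.call_later: the handle is armed in set_response_params, re-armed when reading resumes, and cancelled on pause, EOF, or disconnect. A stripped-down standalone sketch of that scheduling pattern, illustrative only and not aiohttp's actual classes:

import asyncio

class ReadTimeoutWatcher:
    """Mimics the _reschedule_timeout/_drop_timeout/_on_read_timeout trio."""

    def __init__(self, loop, timeout, on_timeout):
        self._loop = loop
        self._timeout = timeout
        self._on_timeout = on_timeout
        self._handle = None

    def reschedule(self):
        # Cancel any pending timer and arm a new one if a timeout is configured.
        if self._handle is not None:
            self._handle.cancel()
        if self._timeout:
            self._handle = self._loop.call_later(self._timeout, self._on_timeout)
        else:
            self._handle = None

    def drop(self):
        # Mirrors _drop_timeout: called on EOF, pause_reading or connection_lost.
        if self._handle is not None:
            self._handle.cancel()
            self._handle = None

async def main(loop):
    watcher = ReadTimeoutWatcher(
        loop, 0.1, lambda: print('Timeout on reading data from socket'))
    watcher.reschedule()        # reading starts: arm the timer
    await asyncio.sleep(0.05)
    watcher.reschedule()        # reading resumed: push the deadline back
    await asyncio.sleep(0.2)    # nothing read in time: the callback fires
    watcher.drop()

loop = asyncio.get_event_loop()
loop.run_until_complete(main(loop))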
17 changes: 4 additions & 13 deletions aiohttp/client_reqrep.py
@@ -189,7 +189,7 @@ def __init__(self, method, url, *,
chunked=None, expect100=False,
loop=None, response_class=None,
proxy=None, proxy_auth=None,
timer=None, session=None, auto_decompress=True,
timer=None, session=None,
ssl=None,
proxy_headers=None,
traces=None):
@@ -214,7 +214,6 @@ def __init__(self, method, url, *,
self.length = None
self.response_class = response_class or ClientResponse
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress
self._ssl = ssl

if loop.get_debug():
@@ -551,7 +550,6 @@ async def send(self, conn):
self.method, self.original_url,
writer=self._writer, continue100=self._continue, timer=self._timer,
request_info=self.request_info,
auto_decompress=self._auto_decompress,
traces=self._traces,
loop=self.loop,
session=self._session
@@ -597,7 +595,7 @@ class ClientResponse(HeadersMixin):

def __init__(self, method, url, *,
writer, continue100, timer,
request_info, auto_decompress,
request_info,
traces, loop, session):
assert isinstance(url, URL)

@@ -614,7 +612,6 @@ def __init__(self, method, url, *,
self._history = ()
self._request_info = request_info
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress # True by default
self._cache = {} # required for @reify method decorator
self._traces = traces
self._loop = loop
@@ -735,23 +732,17 @@ def links(self):

return MultiDictProxy(links)

async def start(self, connection, read_until_eof=False):
async def start(self, connection):
"""Start response processing."""
self._closed = False
self._protocol = connection.protocol
self._connection = connection

connection.protocol.set_response_params(
timer=self._timer,
skip_payload=self.method.lower() == 'head',
read_until_eof=read_until_eof,
auto_decompress=self._auto_decompress)

with self._timer:
while True:
# read response
try:
(message, payload) = await self._protocol.read()
message, payload = await self._protocol.read()
except http.HttpProcessingError as exc:
raise ClientResponseError(
self.request_info, self.history,