diff --git a/CHANGES/2768.feature b/CHANGES/2768.feature new file mode 100644 index 00000000000..0418e638db0 --- /dev/null +++ b/CHANGES/2768.feature @@ -0,0 +1 @@ +Implement ``ClientTimeout`` class and support socket read timeout. \ No newline at end of file diff --git a/aiohttp/client.py b/aiohttp/client.py index 45f2737e33a..d70b5f7a966 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -10,6 +10,7 @@ import warnings from collections.abc import Coroutine +import attr from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr from yarl import URL @@ -38,11 +39,33 @@ __all__ = (client_exceptions.__all__ + # noqa client_reqrep.__all__ + # noqa connector_mod.__all__ + # noqa - ('ClientSession', 'ClientWebSocketResponse', 'request')) + ('ClientSession', 'ClientTimeout', + 'ClientWebSocketResponse', 'request')) -# 5 Minute default read and connect timeout -DEFAULT_TIMEOUT = 5 * 60 +@attr.s(frozen=True, slots=True) +class ClientTimeout: + total = attr.ib(type=float, default=None) + connect = attr.ib(type=float, default=None) + sock_read = attr.ib(type=float, default=None) + sock_connect = attr.ib(type=float, default=None) + + # pool_queue_timeout = attr.ib(type=float, default=None) + # dns_resolution_timeout = attr.ib(type=float, default=None) + # socket_connect_timeout = attr.ib(type=float, default=None) + # connection_acquiring_timeout = attr.ib(type=float, default=None) + # new_connection_timeout = attr.ib(type=float, default=None) + # http_header_timeout = attr.ib(type=float, default=None) + # response_body_timeout = attr.ib(type=float, default=None) + + # to create a timeout specific for a single request, either + # - create a completely new one to overwrite the default + # - or use http://www.attrs.org/en/stable/api.html#attr.evolve + # to overwrite the defaults + + +# 5 Minute default read timeout +DEFAULT_TIMEOUT = ClientTimeout(total=5*60) class ClientSession: @@ -52,8 +75,8 @@ class ClientSession: '_source_traceback', '_connector', 'requote_redirect_url', '_loop', '_cookie_jar', '_connector_owner', '_default_auth', - '_version', '_json_serialize', '_read_timeout', - '_conn_timeout', '_raise_for_status', '_auto_decompress', + '_version', '_json_serialize', + '_timeout', '_raise_for_status', '_auto_decompress', '_trust_env', '_default_headers', '_skip_auto_headers', '_request_class', '_response_class', '_ws_response_class', '_trace_configs']) @@ -71,6 +94,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None, version=http.HttpVersion11, cookie_jar=None, connector_owner=True, raise_for_status=False, read_timeout=sentinel, conn_timeout=None, + timeout=sentinel, auto_decompress=True, trust_env=False, trace_configs=None): @@ -117,9 +141,26 @@ def __init__(self, *, connector=None, loop=None, cookies=None, self._default_auth = auth self._version = version self._json_serialize = json_serialize - self._read_timeout = (read_timeout if read_timeout is not sentinel - else DEFAULT_TIMEOUT) - self._conn_timeout = conn_timeout + if timeout is not sentinel: + self._timeout = timeout + else: + self._timeout = DEFAULT_TIMEOUT + if read_timeout is not sentinel: + if timeout is not sentinel: + raise ValueError("read_timeout and timeout parameters " + "conflict, please setup " + "timeout.read") + else: + self._timeout = attr.evolve(self._timeout, + total=read_timeout) + if conn_timeout is not None: + if timeout is not sentinel: + raise ValueError("conn_timeout and timeout parameters " + "conflict, please setup " + "timeout.connect") + else: + self._timeout = attr.evolve(self._timeout, + connect=conn_timeout) self._raise_for_status = raise_for_status self._auto_decompress = auto_decompress self._trust_env = trust_env @@ -244,11 +285,14 @@ async def _request(self, method, url, *, except ValueError: raise InvalidURL(proxy) + if timeout is sentinel: + timeout = self._timeout + else: + if not isinstance(timeout, ClientTimeout): + timeout = ClientTimeout(total=timeout) # timeout is cumulative for all request operations # (request, redirects, responses, data consuming) - tm = TimeoutHandle( - self._loop, - timeout if timeout is not sentinel else self._read_timeout) + tm = TimeoutHandle(self._loop, timeout.total) handle = tm.start() traces = [ @@ -309,15 +353,17 @@ async def _request(self, method, url, *, expect100=expect100, loop=self._loop, response_class=self._response_class, proxy=proxy, proxy_auth=proxy_auth, timer=timer, - session=self, auto_decompress=self._auto_decompress, + session=self, ssl=ssl, proxy_headers=proxy_headers, traces=traces) # connection timeout try: - with CeilTimeout(self._conn_timeout, loop=self._loop): + with CeilTimeout(self._timeout.connect, + loop=self._loop): conn = await self._connector.connect( req, - traces=traces + traces=traces, + timeout=timeout ) except asyncio.TimeoutError as exc: raise ServerTimeoutError( @@ -326,11 +372,19 @@ async def _request(self, method, url, *, tcp_nodelay(conn.transport, True) tcp_cork(conn.transport, False) + + conn.protocol.set_response_params( + timer=timer, + skip_payload=method.upper() == 'HEAD', + read_until_eof=read_until_eof, + auto_decompress=self._auto_decompress, + read_timeout=timeout.sock_read) + try: try: resp = await req.send(conn) try: - await resp.start(conn, read_until_eof) + await resp.start(conn) except BaseException: resp.close() raise diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index f9fa560195b..91fa45a2d3a 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -2,7 +2,7 @@ from .base_protocol import BaseProtocol from .client_exceptions import (ClientOSError, ClientPayloadError, - ServerDisconnectedError) + ServerDisconnectedError, ServerTimeoutError) from .http import HttpResponseParser from .streams import EMPTY_PAYLOAD, DataQueue @@ -16,7 +16,6 @@ def __init__(self, *, loop=None): self._should_close = False - self._message = None self._payload = None self._skip_payload = False self._payload_parser = None @@ -28,6 +27,9 @@ def __init__(self, *, loop=None): self._upgraded = False self._parser = None + self._read_timeout = None + self._read_timeout_handle = None + @property def upgraded(self): return self._upgraded @@ -55,6 +57,8 @@ def is_connected(self): return self.transport is not None def connection_lost(self, exc): + self._drop_timeout() + if self._payload_parser is not None: with suppress(Exception): self._payload_parser.feed_eof() @@ -78,7 +82,6 @@ def connection_lost(self, exc): self._should_close = True self._parser = None - self._message = None self._payload = None self._payload_parser = None self._reading_paused = False @@ -86,7 +89,8 @@ def connection_lost(self, exc): super().connection_lost(exc) def eof_received(self): - pass + # should call parser.feed_eof() most likely + self._drop_timeout() def pause_reading(self): if not self._reading_paused: @@ -95,6 +99,7 @@ def pause_reading(self): except (AttributeError, NotImplementedError, RuntimeError): pass self._reading_paused = True + self._drop_timeout() def resume_reading(self): if self._reading_paused: @@ -103,15 +108,19 @@ def resume_reading(self): except (AttributeError, NotImplementedError, RuntimeError): pass self._reading_paused = False + self._reschedule_timeout() def set_exception(self, exc): self._should_close = True + self._drop_timeout() super().set_exception(exc) def set_parser(self, parser, payload): self._payload = payload self._payload_parser = parser + self._drop_timeout() + if self._tail: data, self._tail = self._tail, b'' self.data_received(data) @@ -119,8 +128,13 @@ def set_parser(self, parser, payload): def set_response_params(self, *, timer=None, skip_payload=False, read_until_eof=False, - auto_decompress=True): + auto_decompress=True, + read_timeout=None): self._skip_payload = skip_payload + + self._read_timeout = read_timeout + self._reschedule_timeout() + self._parser = HttpResponseParser( self, self._loop, timer=timer, payload_exception=ClientPayloadError, @@ -131,6 +145,26 @@ def set_response_params(self, *, timer=None, data, self._tail = self._tail, b'' self.data_received(data) + def _drop_timeout(self): + if self._read_timeout_handle is not None: + self._read_timeout_handle.cancel() + self._read_timeout_handle = None + + def _reschedule_timeout(self): + timeout = self._read_timeout + if self._read_timeout_handle is not None: + self._read_timeout_handle.cancel() + + if timeout: + self._read_timeout_handle = self._loop.call_later( + timeout, self._on_read_timeout) + else: + self._read_timeout_handle = None + + def _on_read_timeout(self): + self.set_exception( + ServerTimeoutError("Timeout on reading data from socket")) + def data_received(self, data): if not data: return @@ -161,17 +195,26 @@ def data_received(self, data): self._upgraded = upgraded + payload = None for message, payload in messages: if message.should_close: self._should_close = True - self._message = message self._payload = payload if self._skip_payload or message.code in (204, 304): self.feed_data((message, EMPTY_PAYLOAD), 0) else: self.feed_data((message, payload), 0) + if payload is not None: + # new message(s) was processed + # register timeout handler unsubscribing + # either on end-of-stream or immediatelly for + # EMPTY_PAYLOAD + if payload is not EMPTY_PAYLOAD: + payload.on_eof(self._drop_timeout) + else: + self._drop_timeout() if tail: if upgraded: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 4925aa8ce4e..57ba2778b3f 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -189,7 +189,7 @@ def __init__(self, method, url, *, chunked=None, expect100=False, loop=None, response_class=None, proxy=None, proxy_auth=None, - timer=None, session=None, auto_decompress=True, + timer=None, session=None, ssl=None, proxy_headers=None, traces=None): @@ -214,7 +214,6 @@ def __init__(self, method, url, *, self.length = None self.response_class = response_class or ClientResponse self._timer = timer if timer is not None else TimerNoop() - self._auto_decompress = auto_decompress self._ssl = ssl if loop.get_debug(): @@ -551,7 +550,6 @@ async def send(self, conn): self.method, self.original_url, writer=self._writer, continue100=self._continue, timer=self._timer, request_info=self.request_info, - auto_decompress=self._auto_decompress, traces=self._traces, loop=self.loop, session=self._session @@ -597,7 +595,7 @@ class ClientResponse(HeadersMixin): def __init__(self, method, url, *, writer, continue100, timer, - request_info, auto_decompress, + request_info, traces, loop, session): assert isinstance(url, URL) @@ -614,7 +612,6 @@ def __init__(self, method, url, *, self._history = () self._request_info = request_info self._timer = timer if timer is not None else TimerNoop() - self._auto_decompress = auto_decompress # True by default self._cache = {} # required for @reify method decorator self._traces = traces self._loop = loop @@ -735,23 +732,17 @@ def links(self): return MultiDictProxy(links) - async def start(self, connection, read_until_eof=False): + async def start(self, connection): """Start response processing.""" self._closed = False self._protocol = connection.protocol self._connection = connection - connection.protocol.set_response_params( - timer=self._timer, - skip_payload=self.method.lower() == 'head', - read_until_eof=read_until_eof, - auto_decompress=self._auto_decompress) - with self._timer: while True: # read response try: - (message, payload) = await self._protocol.read() + message, payload = await self._protocol.read() except http.HttpProcessingError as exc: raise ClientResponseError( self.request_info, self.history, diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 6a17a43e13a..ab4e8feb18a 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -22,7 +22,7 @@ ssl_errors) from .client_proto import ResponseHandler from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params -from .helpers import PY_36, is_ip_address, noop, sentinel +from .helpers import PY_36, CeilTimeout, is_ip_address, noop, sentinel from .locks import EventResultOrError from .resolver import DefaultResolver @@ -391,7 +391,7 @@ def _available_connections(self, key): return available - async def connect(self, req, traces=None): + async def connect(self, req, traces, timeout): """Get from pool or create new connection.""" key = req.connection_key available = self._available_connections(key) @@ -442,10 +442,7 @@ async def connect(self, req, traces=None): await trace.send_connection_create_start() try: - proto = await self._create_connection( - req, - traces=traces - ) + proto = await self._create_connection(req, traces, timeout) if self._closed: proto.close() raise ClientConnectionError("Connector is closed.") @@ -561,7 +558,7 @@ def _release(self, key, protocol, *, should_close=False): self._cleanup_handle = helpers.weakref_handle( self, '_cleanup', self._keepalive_timeout, self._loop) - async def _create_connection(self, req, traces=None): + async def _create_connection(self, req, traces, timeout): raise NotImplementedError() @@ -747,21 +744,17 @@ async def _resolve_host(self, host, port, traces=None): return self._cached_hosts.next_addrs(key) - async def _create_connection(self, req, traces=None): + async def _create_connection(self, req, traces, timeout): """Create connection. Has same keyword arguments as BaseEventLoop.create_connection. """ if req.proxy: _, proto = await self._create_proxy_connection( - req, - traces=traces - ) + req, traces, timeout) else: _, proto = await self._create_direct_connection( - req, - traces=traces - ) + req, traces, timeout) return proto @@ -821,10 +814,12 @@ def _get_fingerprint(self, req): return None async def _wrap_create_connection(self, *args, - req, client_error=ClientConnectorError, + req, timeout, + client_error=ClientConnectorError, **kwargs): try: - return await self._loop.create_connection(*args, **kwargs) + with CeilTimeout(timeout.sock_connect): + return await self._loop.create_connection(*args, **kwargs) except certificate_errors as exc: raise ClientConnectorCertificateError( req.connection_key, exc) from exc @@ -833,9 +828,8 @@ async def _wrap_create_connection(self, *args, except OSError as exc: raise client_error(req.connection_key, exc) from exc - async def _create_direct_connection(self, req, - *, client_error=ClientConnectorError, - traces=None): + async def _create_direct_connection(self, req, traces, timeout, + *, client_error=ClientConnectorError): sslcontext = self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) @@ -860,7 +854,7 @@ async def _create_direct_connection(self, req, try: transp, proto = await self._wrap_create_connection( - self._factory, host, port, + self._factory, host, port, timeout=timeout, ssl=sslcontext, family=hinfo['family'], proto=hinfo['proto'], flags=hinfo['flags'], server_hostname=hinfo['hostname'] if sslcontext else None, @@ -884,7 +878,7 @@ async def _create_direct_connection(self, req, else: raise last_exc - async def _create_proxy_connection(self, req, traces=None): + async def _create_proxy_connection(self, req, traces, timeout): headers = {} if req.proxy_headers is not None: headers = req.proxy_headers @@ -899,7 +893,7 @@ async def _create_proxy_connection(self, req, traces=None): # create connection to proxy server transport, proto = await self._create_direct_connection( - proxy_req, client_error=ClientProxyConnectionError) + proxy_req, [], timeout, client_error=ClientProxyConnectionError) auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) if auth is not None: @@ -928,7 +922,8 @@ async def _create_proxy_connection(self, req, traces=None): conn = Connection(self, key, proto, self._loop) proxy_resp = await proxy_req.send(conn) try: - resp = await proxy_resp.start(conn, True) + conn._protocol.set_response_params() + resp = await proxy_resp.start(conn) except BaseException: proxy_resp.close() conn.close() @@ -954,7 +949,8 @@ async def _create_proxy_connection(self, req, traces=None): transport.close() transport, proto = await self._wrap_create_connection( - self._factory, ssl=sslcontext, sock=rawsock, + self._factory, timeout=timeout, + ssl=sslcontext, sock=rawsock, server_hostname=req.host, req=req) finally: @@ -987,10 +983,11 @@ def path(self): """Path to unix socket.""" return self._path - async def _create_connection(self, req, traces=None): + async def _create_connection(self, req, traces, timeout): try: - _, proto = await self._loop.create_unix_connection( - self._factory, self._path) + with CeilTimeout(timeout.sock_connect): + _, proto = await self._loop.create_unix_connection( + self._factory, self._path) except OSError as exc: raise ClientConnectorError(req.connection_key, exc) from exc diff --git a/docs/client_quickstart.rst b/docs/client_quickstart.rst index d2619b5be3d..514fbc432df 100644 --- a/docs/client_quickstart.rst +++ b/docs/client_quickstart.rst @@ -394,16 +394,40 @@ multiple writer tasks which can only send data asynchronously (by Timeouts ======== -By default all IO operations have 5min timeout. The timeout may be -overridden by passing ``timeout`` parameter into -:meth:`ClientSession.get` and family:: +Timeout settings a stored in :class:`ClientTimeout` data structure. - async with session.get('https://github.com', timeout=60) as r: +By default *aiohttp* uses a *total* 5min timeout, it means that the +whole operation should finish in 5 minutes. + +The value could be overridden by *timeout* parameter for the session:: + + timeout = aiohttp.ClientTimeout(total=60) + async with aiohttp.ClientSession(timeout=timeout) as session: ... -``None`` or ``0`` disables timeout check. +Timeout could be overridden for a request like :meth:`ClientSession.get`:: -.. note:: + async with session.get(url, timeout=timeout) as resp: + ... + +Supported :class:`ClientTimeout` fields are: + + ``total`` + + The whole operation time including connection + establishment, request sending and response reading. + + ``connect`` + + The maximum time for connection establishment. + + ``sock_read`` + + The maximum allowed timeout for period between reading a new + data portion from a peer. + +All fields a floats, ``None`` or ``0`` disables a particular timeout check. + +Thus the default timeout is:: - Timeout is cumulative time, it includes all operations like sending request, - redirects, response parsing, consuming response, etc. + aiohttp.ClientTimeout(total=5*60, connect=None, sock_read=None) diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 19d87da830e..deb6846a15e 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -45,6 +45,7 @@ The client session supports the context manager protocol for self closing. version=aiohttp.HttpVersion11, \ cookie_jar=None, read_timeout=None, \ conn_timeout=None, \ + timeout=sentinel, \ raise_for_status=False, \ connector_owner=True, \ auto_decompress=True, proxies=None) @@ -113,16 +114,27 @@ The client session supports the context manager protocol for self closing. Automatically call :meth:`ClientResponse.raise_for_status()` for each response, ``False`` by default. - .. versionadded:: 2.0 + :param timeout: a :class:`ClientTimeout` settings structure, 5min + total timeout by default. + + .. versionadded:: 3.3 :param float read_timeout: Request operations timeout. ``read_timeout`` is cumulative for all request operations (request, redirects, responses, data consuming). By default, the read timeout is 5*60 seconds. Use ``None`` or ``0`` to disable timeout checks. + .. deprecated:: 3.3 + + Use ``timeout`` parameter instead. + :param float conn_timeout: timeout for connection establishing (optional). Values ``0`` or ``None`` mean no timeout. + .. deprecated:: 3.3 + + Use ``timeout`` parameter instead. + :param bool connector_owner: Close connector instance on session closing. @@ -197,7 +209,7 @@ The client session supports the context manager protocol for self closing. max_redirects=10,\ compress=None, chunked=None, expect100=False,\ read_until_eof=True, proxy=None, proxy_auth=None,\ - timeout=5*60, ssl=None, \ + timeout=sentinel, ssl=None, \ verify_ssl=None, fingerprint=None, \ ssl_context=None, proxy_headers=None) :async-with: @@ -278,8 +290,15 @@ The client session supports the context manager protocol for self closing. :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP Basic Authorization (optional) - :param int timeout: override the session's timeout - (``read_timeout``) for IO operations. + :param int timeout: override the session's timeout. + + .. versionchanged:: 3.3 + + The parameter is :class:`ClientTimeout` instance, + :class:`float` is still supported for sake of backward + compatibility. + + If :class:`float` is passed it is a *total* timeout. :param ssl: SSL validation mode. ``None`` for default SSL check (:func:`ssl.create_default_context` is used), @@ -1449,12 +1468,53 @@ Utilities --------- +ClientTimeout +^^^^^^^^^^^^^ + +.. class:: ClientTimeout(*, total=None, connect=None, \ + sock_connect, sock_read=None) + + A data class for client timeout settings. + + .. attribute:: total + + Total timeout for the whole request. + + :class:`float`, ``None`` by default. + + .. attribute:: connect + + Total timeout for acquiring a connection from pool. The time + consists connection establishment for a new connection or + waiting for a free connection from a pool if pool connection + limits are exceeded. + + For pure socket connection establishment time use + :attr:`sock_connect`. + + :class:`float`, ``None`` by default. + + .. attribute:: sock_connect + + A timeout for connecting to a peer for a new connection, not + given from a pool. See also :attr:`connect`. + + :class:`float`, ``None`` by default. + + .. attribute:: sock_read + + A timeout for reading a portion of data from a peer. + + :class:`float`, ``None`` by default. + + .. versionadded:: 3.3 + RequestInfo ^^^^^^^^^^^ .. class:: RequestInfo() - A namedtuple with request URL and headers from :class:`ClientRequest` + A data class with request URL and headers from :class:`ClientRequest` object, available as :attr:`ClientResponse.request_info` attribute. .. attribute:: url diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 7fc8af3c430..620741511b8 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -2598,3 +2598,18 @@ async def handler(request): with pytest.raises(aiohttp.ClientConnectionError): await resp.content.readline() + + +async def test_read_timeout(aiohttp_client): + async def handler(request): + await asyncio.sleep(5) + return web.Response() + + app = web.Application() + app.add_routes([web.get('/', handler)]) + + timeout = aiohttp.ClientTimeout(sock_read=0.1) + client = await aiohttp_client(app, timeout=timeout) + + with pytest.raises(aiohttp.ServerTimeoutError): + await client.get('/') diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 9868b560ce8..64c3b2cd7ec 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -33,7 +33,7 @@ async def test_client_proto_bad_message(loop): proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) - proto.set_response_params(read_until_eof=True) + proto.set_response_params() proto.data_received(b'HTTP\r\n\r\n') assert proto.should_close @@ -71,11 +71,11 @@ async def test_client_protocol_readuntil_eof(loop): continue100=None, timer=TimerNoop(), request_info=mock.Mock(), - auto_decompress=True, traces=[], loop=loop, session=mock.Mock()) - await response.start(conn, read_until_eof=True) + proto.set_response_params(read_until_eof=True) + await response.start(conn) assert not response.content.is_eof() @@ -96,3 +96,35 @@ async def test_empty_data(loop): proto.data_received(b'') # do nothing + + +async def test_schedule_timeout(loop): + proto = ResponseHandler(loop=loop) + proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is not None + + +async def test_drop_timeout(loop): + proto = ResponseHandler(loop=loop) + proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is not None + proto._drop_timeout() + assert proto._read_timeout_handle is None + + +async def test_reschedule_timeout(loop): + proto = ResponseHandler(loop=loop) + proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is not None + h = proto._read_timeout_handle + proto._reschedule_timeout() + assert proto._read_timeout_handle is not None + assert proto._read_timeout_handle is not h + + +async def test_eof_received(loop): + proto = ResponseHandler(loop=loop) + proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is not None + proto.eof_received() + assert proto._read_timeout_handle is None diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 22374f03fdf..0087fcaec1f 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -1210,7 +1210,6 @@ async def send(self, conn): continue100=self._continue, timer=self._timer, request_info=self.request_info, - auto_decompress=self._auto_decompress, traces=self._traces, loop=self.loop, session=self._session) @@ -1219,7 +1218,7 @@ async def send(self, conn): called = True return resp - async def create_connection(req, traces=None): + async def create_connection(req, traces, timeout): assert isinstance(req, CustomRequest) return mock.Mock() connector = BaseConnector(loop=loop) diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 8eefa4b9fce..9dbf4d05036 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -29,7 +29,6 @@ async def test_http_processing_error(session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -54,7 +53,6 @@ def test_del(session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -79,7 +77,6 @@ def test_close(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -97,7 +94,6 @@ def test_wait_for_100_1(loop, session): request_info=mock.Mock(), writer=mock.Mock(), timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -112,7 +108,6 @@ def test_wait_for_100_2(loop, session): continue100=None, writer=mock.Mock(), timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -126,7 +121,6 @@ def test_repr(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -142,7 +136,6 @@ def test_repr_non_ascii_url(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -156,7 +149,6 @@ def test_repr_non_ascii_reason(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -171,7 +163,6 @@ def test_url_obj_deprecated(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -185,7 +176,6 @@ async def test_read_and_release_connection(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -208,7 +198,6 @@ async def test_read_and_release_connection_with_error(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -227,7 +216,6 @@ async def test_release(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -252,7 +240,6 @@ def run(conn): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -270,7 +257,6 @@ async def test_response_eof(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -289,7 +275,6 @@ async def test_response_eof_upgraded(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -308,7 +293,6 @@ async def test_response_eof_after_connection_detach(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -327,7 +311,6 @@ async def test_text(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -353,7 +336,6 @@ async def test_text_bad_encoding(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -382,7 +364,6 @@ async def test_text_custom_encoding(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -410,7 +391,6 @@ async def test_text_detect_encoding(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -436,7 +416,6 @@ async def test_text_detect_encoding_if_invalid_charset(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -463,7 +442,6 @@ async def test_text_after_read(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -489,7 +467,6 @@ async def test_json(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -515,7 +492,6 @@ async def test_json_extended_content_type(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -542,7 +518,6 @@ async def test_json_custom_content_type(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -568,7 +543,6 @@ async def test_json_custom_loader(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -589,7 +563,6 @@ async def test_json_invalid_content_type(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -609,7 +582,6 @@ async def test_json_no_content(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -627,7 +599,6 @@ async def test_json_override_encoding(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -655,7 +626,6 @@ def test_get_encoding_unknown(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -672,7 +642,6 @@ def test_raise_for_status_2xx(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -687,7 +656,6 @@ def test_raise_for_status_4xx(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -705,7 +673,6 @@ def test_resp_host(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -718,7 +685,6 @@ def test_content_type(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -733,7 +699,6 @@ def test_content_type_no_header(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -748,7 +713,6 @@ def test_charset(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -763,7 +727,6 @@ def test_charset_no_header(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -778,7 +741,6 @@ def test_charset_no_charset(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -793,7 +755,6 @@ def test_content_disposition_full(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -813,7 +774,6 @@ def test_content_disposition_no_parameters(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -830,7 +790,6 @@ def test_content_disposition_no_header(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -845,7 +804,6 @@ def test_content_disposition_cache(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock()) @@ -868,7 +826,6 @@ def test_response_request_info(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock() @@ -892,7 +849,6 @@ def test_request_info_in_exception(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock() @@ -918,7 +874,6 @@ def test_no_redirect_history_in_exception(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock() @@ -948,7 +903,6 @@ def test_redirect_history_in_exception(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock() @@ -967,7 +921,6 @@ def test_redirect_history_in_exception(): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=mock.Mock(), session=mock.Mock() @@ -994,7 +947,6 @@ async def test_response_read_triggers_callback(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, loop=loop, session=session, traces=[trace] @@ -1028,7 +980,6 @@ def test_response_real_url(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -1043,7 +994,6 @@ def test_response_links_comma_separated(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -1073,7 +1023,6 @@ def test_response_links_multiple_headers(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -1106,7 +1055,6 @@ def test_response_links_no_rel(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -1132,7 +1080,6 @@ def test_response_links_quoted(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -1158,7 +1105,6 @@ def test_response_links_relative(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) @@ -1184,7 +1130,6 @@ def test_response_links_empty(loop, session): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=loop, session=session) diff --git a/tests/test_client_session.py b/tests/test_client_session.py index ab4512c3521..047ed796f66 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -367,7 +367,7 @@ async def test_reraise_os_error(create_session): req.send = mock.Mock(side_effect=err) session = create_session(request_class=req_factory) - async def create_connection(req, traces=None): + async def create_connection(req, traces, timeout): # return self.transport, self.protocol return mock.Mock() session._connector._create_connection = create_connection @@ -393,12 +393,12 @@ class UnexpectedException(BaseException): connections = [] original_connect = session._connector.connect - async def connect(req, traces=None): - conn = await original_connect(req, traces=traces) + async def connect(req, traces, timeout): + conn = await original_connect(req, traces, timeout) connections.append(conn) return conn - async def create_connection(req, traces=None): + async def create_connection(req, traces, timeout): # return self.transport, self.protocol conn = mock.Mock() return conn diff --git a/tests/test_connector.py b/tests/test_connector.py index f0ef56a7a4e..75894378264 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -16,7 +16,7 @@ import aiohttp from aiohttp import client, web -from aiohttp.client import ClientRequest +from aiohttp.client import ClientRequest, ClientTimeout from aiohttp.client_reqrep import ConnectionKey from aiohttp.connector import Connection, _DNSCacheTable from aiohttp.test_utils import make_mocked_coro, unused_port @@ -222,7 +222,7 @@ def test_del_empty_conector(loop): async def test_create_conn(loop): conn = aiohttp.BaseConnector(loop=loop) with pytest.raises(NotImplementedError): - await conn._create_connection(object()) + await conn._create_connection(object(), [], object()) def test_context_manager(loop): @@ -488,7 +488,7 @@ async def certificate_error(*args, **kwargs): conn._loop.create_connection = certificate_error with pytest.raises(aiohttp.ClientConnectorCertificateError) as ctx: - await conn.connect(req) + await conn.connect(req, [], ClientTimeout()) assert isinstance(ctx.value, ssl.CertificateError) assert isinstance(ctx.value.certificate_error, ssl.CertificateError) @@ -551,7 +551,7 @@ async def create_connection(*args, **kwargs): if ip == ip4: fingerprint_error = True - tr, pr = mock.Mock(), None + tr, pr = mock.Mock(), mock.Mock() def get_extra_info(param): if param == 'sslcontext': @@ -572,7 +572,7 @@ def get_extra_info(param): if ip == ip5: connected = True - tr, pr = mock.Mock(), None + tr, pr = mock.Mock(), mock.Mock() def get_extra_info(param): if param == 'sslcontext': @@ -592,7 +592,7 @@ def get_extra_info(param): conn._loop.create_connection = create_connection - await conn.connect(req) + await conn.connect(req, [], ClientTimeout()) assert ips == ips_tried assert os_error @@ -932,10 +932,10 @@ def test_dns_error(loop): req = ClientRequest( 'GET', URL('http://www.python.org'), - loop=loop, - ) + loop=loop) + with pytest.raises(aiohttp.ClientConnectorError): - loop.run_until_complete(connector.connect(req)) + loop.run_until_complete(connector.connect(req, [], ClientTimeout())) def test_get_pop_empty_conns(loop): @@ -1008,7 +1008,7 @@ async def test_connect(loop, key): conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) - connection = await conn.connect(req) + connection = await conn.connect(req, [], ClientTimeout()) assert not conn._create_connection.called assert connection._protocol is proto assert connection.transport is proto.transport @@ -1050,7 +1050,7 @@ async def test_connect_tracing(loop): conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) - conn2 = await conn.connect(req, traces=traces) + conn2 = await conn.connect(req, traces, ClientTimeout()) conn2.release() on_connection_create_start.assert_called_with( @@ -1076,7 +1076,7 @@ async def test_close_during_connect(loop): conn._create_connection = mock.Mock() conn._create_connection.return_value = fut - task = loop.create_task(conn.connect(req)) + task = loop.create_task(conn.connect(req, None, ClientTimeout())) await asyncio.sleep(0, loop=loop) conn.close() @@ -1379,7 +1379,7 @@ async def test_connect_with_limit(loop, key): conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) - connection1 = await conn.connect(req) + connection1 = await conn.connect(req, None, ClientTimeout()) assert connection1._protocol == proto assert 1 == len(conn._acquired) @@ -1391,7 +1391,7 @@ async def test_connect_with_limit(loop, key): async def f(): nonlocal acquired - connection2 = await conn.connect(req) + connection2 = await conn.connect(req, None, ClientTimeout()) acquired = True assert 1 == len(conn._acquired) assert 1 == len(conn._acquired_per_host[key]) @@ -1445,13 +1445,10 @@ async def test_connect_queued_operation_tracing(loop, key): conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) - connection1 = await conn.connect(req, traces=traces) + connection1 = await conn.connect(req, traces, ClientTimeout()) async def f(): - connection2 = await conn.connect( - req, - traces=traces - ) + connection2 = await conn.connect(req, traces, ClientTimeout()) on_connection_queued_start.assert_called_with( session, trace_config_ctx, @@ -1500,7 +1497,7 @@ async def test_connect_reuseconn_tracing(loop, key): conn = aiohttp.BaseConnector(loop=loop, limit=1) conn._conns[key] = [(proto, loop.time())] - conn2 = await conn.connect(req, traces=traces) + conn2 = await conn.connect(req, traces, ClientTimeout()) conn2.release() on_connection_reuseconn.assert_called_with( @@ -1524,11 +1521,11 @@ async def test_connect_with_limit_and_limit_per_host(loop, key): conn._create_connection.return_value.set_result(proto) acquired = False - connection1 = await conn.connect(req) + connection1 = await conn.connect(req, None, ClientTimeout()) async def f(): nonlocal acquired - connection2 = await conn.connect(req) + connection2 = await conn.connect(req, None, ClientTimeout()) acquired = True assert 1 == len(conn._acquired) assert 1 == len(conn._acquired_per_host[key]) @@ -1558,11 +1555,11 @@ async def test_connect_with_no_limit_and_limit_per_host(loop, key): conn._create_connection.return_value.set_result(proto) acquired = False - connection1 = await conn.connect(req) + connection1 = await conn.connect(req, None, ClientTimeout()) async def f(): nonlocal acquired - connection2 = await conn.connect(req) + connection2 = await conn.connect(req, None, ClientTimeout()) acquired = True connection2.release() @@ -1590,11 +1587,11 @@ async def test_connect_with_no_limits(loop, key): conn._create_connection.return_value.set_result(proto) acquired = False - connection1 = await conn.connect(req) + connection1 = await conn.connect(req, None, ClientTimeout()) async def f(): nonlocal acquired - connection2 = await conn.connect(req) + connection2 = await conn.connect(req, None, ClientTimeout()) acquired = True assert 1 == len(conn._acquired) assert 1 == len(conn._acquired_per_host[key]) @@ -1623,7 +1620,7 @@ async def test_connect_with_limit_cancelled(loop): conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) - connection = await conn.connect(req) + connection = await conn.connect(req, None, ClientTimeout()) assert connection._protocol == proto assert connection.transport == proto.transport @@ -1631,7 +1628,8 @@ async def test_connect_with_limit_cancelled(loop): with pytest.raises(asyncio.TimeoutError): # limit exhausted - await asyncio.wait_for(conn.connect(req), 0.01, loop=loop) + await asyncio.wait_for(conn.connect(req, None, ClientTimeout()), + 0.01, loop=loop) connection.close() @@ -1646,7 +1644,7 @@ def check_with_exc(err): with pytest.raises(Exception): req = mock.Mock() - yield from conn.connect(req) + yield from conn.connect(req, None, ClientTimeout()) assert not conn._waiters @@ -1670,7 +1668,7 @@ async def test_connect_with_limit_concurrent(loop): # Use a real coroutine for _create_connection; a mock would mask # problems that only happen when the method yields. - async def create_connection(req, traces=None): + async def create_connection(req, traces, timeout): nonlocal num_connections num_connections += 1 await asyncio.sleep(0, loop=loop) @@ -1701,7 +1699,7 @@ async def f(start=True): return num_requests += 1 if not start: - connection = await conn.connect(req) + connection = await conn.connect(req, None, ClientTimeout()) await asyncio.sleep(0, loop=loop) connection.release() tasks = [ @@ -1725,7 +1723,7 @@ async def test_connect_waiters_cleanup(loop): conn = aiohttp.BaseConnector(loop=loop, limit=1) conn._available_connections = mock.Mock(return_value=0) - t = loop.create_task(conn.connect(req)) + t = loop.create_task(conn.connect(req, None, ClientTimeout())) await asyncio.sleep(0, loop=loop) assert conn._waiters.keys() @@ -1744,7 +1742,7 @@ async def test_connect_waiters_cleanup_key_error(loop): conn = aiohttp.BaseConnector(loop=loop, limit=1) conn._available_connections = mock.Mock(return_value=0) - t = loop.create_task(conn.connect(req)) + t = loop.create_task(conn.connect(req, None, ClientTimeout())) await asyncio.sleep(0, loop=loop) assert conn._waiters.keys() @@ -1771,7 +1769,7 @@ async def test_close_with_acquired_connection(loop): conn._create_connection.return_value = loop.create_future() conn._create_connection.return_value.set_result(proto) - connection = await conn.connect(req) + connection = await conn.connect(req, None, ClientTimeout()) assert 1 == len(conn._acquired) conn.close() @@ -1840,7 +1838,7 @@ async def test_error_on_connection(loop, key): fut = loop.create_future() exc = OSError() - async def create_connection(req, traces=None): + async def create_connection(req, traces, timeout): nonlocal i i += 1 if i == 1: @@ -1851,8 +1849,8 @@ async def create_connection(req, traces=None): conn._create_connection = create_connection - t1 = loop.create_task(conn.connect(req)) - t2 = loop.create_task(conn.connect(req)) + t1 = loop.create_task(conn.connect(req, None, ClientTimeout())) + t2 = loop.create_task(conn.connect(req, None, ClientTimeout())) await asyncio.sleep(0, loop=loop) assert not t1.done() assert not t2.done() @@ -1885,7 +1883,7 @@ async def create_connection(req, traces=None): conn._acquired.add(proto) - conn2 = loop.create_task(conn.connect(req)) + conn2 = loop.create_task(conn.connect(req, None, ClientTimeout())) await asyncio.sleep(0, loop=loop) conn2.cancel() @@ -1905,7 +1903,7 @@ async def test_error_on_connection_with_cancelled_waiter(loop, key): fut2 = loop.create_future() exc = OSError() - async def create_connection(req, traces=None): + async def create_connection(req, traces, timeout): nonlocal i i += 1 if i == 1: @@ -1918,9 +1916,9 @@ async def create_connection(req, traces=None): conn._create_connection = create_connection - t1 = loop.create_task(conn.connect(req)) - t2 = loop.create_task(conn.connect(req)) - t3 = loop.create_task(conn.connect(req)) + t1 = loop.create_task(conn.connect(req, None, ClientTimeout())) + t2 = loop.create_task(conn.connect(req, None, ClientTimeout())) + t3 = loop.create_task(conn.connect(req, None, ClientTimeout())) await asyncio.sleep(0, loop=loop) assert not t1.done() assert not t2.done() @@ -1963,10 +1961,9 @@ def test_unix_connector_not_found(loop): req = ClientRequest( 'GET', URL('http://www.python.org'), - loop=loop, - ) + loop=loop) with pytest.raises(aiohttp.ClientConnectorError): - loop.run_until_complete(connector.connect(req)) + loop.run_until_complete(connector.connect(req, None, ClientTimeout())) @pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), @@ -1978,10 +1975,9 @@ def test_unix_connector_permission(loop): req = ClientRequest( 'GET', URL('http://www.python.org'), - loop=loop, - ) + loop=loop) with pytest.raises(aiohttp.ClientConnectorError): - loop.run_until_complete(connector.connect(req)) + loop.run_until_complete(connector.connect(req, None, ClientTimeout())) def test_default_use_dns_cache(loop): @@ -1999,7 +1995,7 @@ async def test_resolver_not_called_with_address_is_ip(loop): response_class=mock.Mock()) with pytest.raises(OSError): - await connector.connect(req) + await connector.connect(req, None, ClientTimeout()) resolver.resolve.assert_not_called() diff --git a/tests/test_proxy.py b/tests/test_proxy.py index ccb8e9393eb..22698068777 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -39,7 +39,7 @@ def test_connect(self, ClientRequestMock): req = ClientRequest( 'GET', URL('http://www.python.org'), proxy=URL('http://proxy.example.com'), - loop=self.loop + loop=self.loop, ) self.assertEqual(str(req.proxy), 'http://proxy.example.com') @@ -52,7 +52,8 @@ def test_connect(self, ClientRequestMock): }) self.loop.create_connection = make_mocked_coro( (proto.transport, proto)) - conn = self.loop.run_until_complete(connector.connect(req)) + conn = self.loop.run_until_complete( + connector.connect(req, None, aiohttp.ClientTimeout())) self.assertEqual(req.url, URL('http://www.python.org')) self.assertIs(conn._protocol, proto) self.assertIs(conn.transport, proto.transport) @@ -82,7 +83,8 @@ def test_proxy_headers(self, ClientRequestMock): }) self.loop.create_connection = make_mocked_coro( (proto.transport, proto)) - conn = self.loop.run_until_complete(connector.connect(req)) + conn = self.loop.run_until_complete(connector.connect( + req, None, aiohttp.ClientTimeout())) self.assertEqual(req.url, URL('http://www.python.org')) self.assertIs(conn._protocol, proto) self.assertIs(conn.transport, proto.transport) @@ -118,7 +120,8 @@ def test_proxy_dns_error(self): ) expected_headers = dict(req.headers) with self.assertRaises(aiohttp.ClientConnectorError): - self.loop.run_until_complete(connector.connect(req)) + self.loop.run_until_complete(connector.connect( + req, None, aiohttp.ClientTimeout())) self.assertEqual(req.url.path, '/') self.assertEqual(dict(req.headers), expected_headers) @@ -138,7 +141,8 @@ def test_proxy_connection_error(self): loop=self.loop, ) with self.assertRaises(aiohttp.ClientProxyConnectionError): - self.loop.run_until_complete(connector.connect(req)) + self.loop.run_until_complete(connector.connect( + req, None, aiohttp.ClientTimeout())) @mock.patch('aiohttp.connector.ClientRequest') def test_https_connect(self, ClientRequestMock): @@ -151,7 +155,6 @@ def test_https_connect(self, ClientRequestMock): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=self.loop, session=mock.Mock()) @@ -171,7 +174,8 @@ def test_https_connect(self, ClientRequestMock): proxy=URL('http://proxy.example.com'), loop=self.loop, ) - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout())) self.assertEqual(req.url.path, '/') self.assertEqual(proxy_req.method, 'CONNECT') @@ -194,7 +198,6 @@ def test_https_connect_certificate_error(self, ClientRequestMock): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=self.loop, session=mock.Mock()) @@ -230,7 +233,8 @@ def create_connection(*args, **kwargs): loop=self.loop, ) with self.assertRaises(aiohttp.ClientConnectorCertificateError): - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete(connector._create_connection( + req, None, aiohttp.ClientTimeout())) @mock.patch('aiohttp.connector.ClientRequest') def test_https_connect_ssl_error(self, ClientRequestMock): @@ -243,7 +247,6 @@ def test_https_connect_ssl_error(self, ClientRequestMock): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=self.loop, session=mock.Mock()) @@ -279,7 +282,8 @@ def create_connection(*args, **kwargs): loop=self.loop, ) with self.assertRaises(aiohttp.ClientConnectorSSLError): - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete(connector._create_connection( + req, None, aiohttp.ClientTimeout())) @mock.patch('aiohttp.connector.ClientRequest') def test_https_connect_runtime_error(self, ClientRequestMock): @@ -292,7 +296,6 @@ def test_https_connect_runtime_error(self, ClientRequestMock): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=self.loop, session=mock.Mock()) @@ -315,7 +318,8 @@ def test_https_connect_runtime_error(self, ClientRequestMock): ) with self.assertRaisesRegex( RuntimeError, "Transport does not expose socket instance"): - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete(connector._create_connection( + req, None, aiohttp.ClientTimeout())) self.loop.run_until_complete(proxy_req.close()) proxy_resp.close() @@ -332,7 +336,6 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=self.loop, session=mock.Mock()) @@ -356,7 +359,8 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock): ) with self.assertRaisesRegex( aiohttp.ClientHttpProxyError, "400, message='bad request'"): - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete(connector._create_connection( + req, None, aiohttp.ClientTimeout())) self.loop.run_until_complete(proxy_req.close()) proxy_resp.close() @@ -373,7 +377,6 @@ def test_https_connect_resp_start_error(self, ClientRequestMock): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=self.loop, session=mock.Mock()) @@ -396,7 +399,8 @@ def test_https_connect_resp_start_error(self, ClientRequestMock): loop=self.loop, ) with self.assertRaisesRegex(OSError, "error message"): - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete(connector._create_connection( + req, None, aiohttp.ClientTimeout())) @mock.patch('aiohttp.connector.ClientRequest') def test_request_port(self, ClientRequestMock): @@ -418,7 +422,8 @@ def test_request_port(self, ClientRequestMock): proxy=URL('http://proxy.example.com'), loop=self.loop, ) - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete(connector._create_connection( + req, None, aiohttp.ClientTimeout())) self.assertEqual(req.url, URL('http://localhost:1234/path')) def test_proxy_auth_property(self): @@ -447,7 +452,6 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=self.loop, session=mock.Mock()) @@ -467,7 +471,8 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock): proxy=URL('http://proxy.example.com'), loop=self.loop, ) - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete(connector._create_connection( + req, None, aiohttp.ClientTimeout())) self.loop.create_connection.assert_called_with( mock.ANY, @@ -498,7 +503,6 @@ def test_https_auth(self, ClientRequestMock): writer=mock.Mock(), continue100=None, timer=TimerNoop(), - auto_decompress=True, traces=[], loop=self.loop, session=mock.Mock()) @@ -523,7 +527,8 @@ def test_https_auth(self, ClientRequestMock): ) self.assertNotIn('AUTHORIZATION', req.headers) self.assertNotIn('PROXY-AUTHORIZATION', req.headers) - self.loop.run_until_complete(connector._create_connection(req)) + self.loop.run_until_complete( + connector._create_connection(req, None, aiohttp.ClientTimeout())) self.assertEqual(req.url.path, '/') self.assertNotIn('AUTHORIZATION', req.headers) @@ -534,7 +539,7 @@ def test_https_auth(self, ClientRequestMock): connector._resolve_host.assert_called_with( 'proxy.example.com', 80, - traces=None) + traces=mock.ANY) self.loop.run_until_complete(proxy_req.close()) proxy_resp.close()