Skip to content

Commit

Permalink
Merge pull request #896 from jchampio/dev/create_future
Browse files Browse the repository at this point in the history
Use a create_future compatibility wrapper instead of creating Futures directly
  • Loading branch information
asvetlov committed Jun 2, 2016
2 parents 67c7bbf + f2c742f commit 051d36f
Show file tree
Hide file tree
Showing 25 changed files with 149 additions and 108 deletions.
4 changes: 2 additions & 2 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .errors import WSServerHandshakeError
from .websocket import WS_KEY, WebSocketParser, WebSocketWriter
from .websocket_client import ClientWebSocketResponse
from . import hdrs
from . import hdrs, helpers


__all__ = ('ClientSession', 'request', 'get', 'options', 'head',
Expand Down Expand Up @@ -432,7 +432,7 @@ def close(self):
if not self.closed:
self._connector.close()
self._connector = None
ret = asyncio.Future(loop=self._loop)
ret = helpers.create_future(self._loop)
ret.set_result(None)
return ret

Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def update_expect_continue(self, expect=False):
expect = True

if expect:
self._continue = asyncio.Future(loop=self.loop)
self._continue = helpers.create_future(self.loop)

@asyncio.coroutine
def write_bytes(self, request, reader):
Expand Down
6 changes: 3 additions & 3 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from math import ceil
from types import MappingProxyType

from . import hdrs
from . import hdrs, helpers
from .client import ClientRequest
from .errors import ServerDisconnectedError
from .errors import HttpProxyError, ProxyConnectionError
Expand Down Expand Up @@ -222,7 +222,7 @@ def _start_cleanup_task(self):

def close(self):
"""Close all opened transports."""
ret = asyncio.Future(loop=self._loop)
ret = helpers.create_future(self._loop)
ret.set_result(None)
if self._closed:
return ret
Expand Down Expand Up @@ -282,7 +282,7 @@ def connect(self, req):

limit = self._limit
if limit is not None:
fut = asyncio.Future(loop=self._loop)
fut = helpers.create_future(self._loop)
waiters = self._waiters[key]

# The limit defines the maximum number of concurrent connections
Expand Down
12 changes: 11 additions & 1 deletion aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
ensure_future = asyncio.async


__all__ = ('BasicAuth', 'FormData', 'parse_mimetype', 'Timeout')
__all__ = ('BasicAuth', 'create_future', 'FormData', 'parse_mimetype',
'Timeout')


class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
Expand Down Expand Up @@ -70,6 +71,15 @@ def encode(self):
return 'Basic %s' % base64.b64encode(creds).decode(self.encoding)


def create_future(loop):
"""Compatiblity wrapper for the loop.create_future() call introduced in
3.5.2."""
if hasattr(loop, 'create_future'):
return loop.create_future()
else:
return asyncio.Future(loop=loop)


class FormData:
"""Helper class for multipart/form-data and
application/x-www-form-urlencoded body generation."""
Expand Down
7 changes: 4 additions & 3 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import traceback

from .log import internal_logger
from . import helpers

__all__ = (
'EofStream', 'StreamReader', 'DataQueue', 'ChunksQueue',
Expand Down Expand Up @@ -150,7 +151,7 @@ def wait_eof(self):
return

assert self._eof_waiter is None
self._eof_waiter = asyncio.Future(loop=self._loop)
self._eof_waiter = helpers.create_future(self._loop)
try:
yield from self._eof_waiter
finally:
Expand Down Expand Up @@ -192,7 +193,7 @@ def _create_waiter(self, func_name):
if self._waiter is not None:
raise RuntimeError('%s() called while another coroutine is '
'already waiting for incoming data' % func_name)
return asyncio.Future(loop=self._loop)
return helpers.create_future(self._loop)

@asyncio.coroutine
def readline(self):
Expand Down Expand Up @@ -441,7 +442,7 @@ def read(self):
raise self._exception

assert not self._waiter
self._waiter = asyncio.Future(loop=self._loop)
self._waiter = helpers.create_future(self._loop)
try:
yield from self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def run(loop, fut):
listen_addr, ssl=sslcontext)
server = thread_loop.run_until_complete(server_coroutine)

waiter = asyncio.Future(loop=thread_loop)
waiter = helpers.create_future(thread_loop)
loop.call_soon_threadsafe(
fut.set_result, (thread_loop, waiter,
server.sockets[0].getsockname()))
Expand All @@ -143,7 +143,7 @@ def run(loop, fut):
thread_loop.close()
gc.collect()

fut = asyncio.Future(loop=loop)
fut = helpers.create_future(loop)
server_thread = threading.Thread(target=run, args=(loop, fut))
server_thread.start()

Expand Down
4 changes: 2 additions & 2 deletions aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from multidict import upstr

from . import hdrs
from . import hdrs, helpers
from .abc import AbstractRouter, AbstractMatchInfo, AbstractView
from .protocol import HttpVersion11
from .web_exceptions import (HTTPMethodNotAllowed, HTTPNotFound,
Expand Down Expand Up @@ -528,7 +528,7 @@ def _sendfile_system(self, req, resp, fobj, count):
loop = req.app.loop
out_fd = transport.get_extra_info("socket").fileno()
in_fd = fobj.fileno()
fut = asyncio.Future(loop=loop)
fut = helpers.create_future(loop)

self._sendfile_cb(fut, out_fd, in_fd, 0, count, loop, False)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_functional_oldstyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ def test_POST_STREAM_DATA(self):
with open(fname, 'rb') as f:
data = f.read()

fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)

@asyncio.coroutine
def stream():
Expand Down
5 changes: 3 additions & 2 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest
import aiohttp
from aiohttp import BaseConnector
from aiohttp import helpers
from aiohttp.client_reqrep import ClientRequest, ClientResponse

import os.path
Expand Down Expand Up @@ -716,7 +717,7 @@ def test_data_file(self):
self.loop.run_until_complete(req.close())

def test_data_stream_exc(self):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)

def gen():
yield b'binary data'
Expand Down Expand Up @@ -756,7 +757,7 @@ def gen():
resp.close()

def test_data_stream_exc_chain(self):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)

def gen():
yield from fut
Expand Down
23 changes: 12 additions & 11 deletions tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import unittest.mock

import aiohttp
from aiohttp import helpers
from aiohttp.client_reqrep import ClientResponse


Expand Down Expand Up @@ -71,7 +72,7 @@ def test_repr(self):

def test_read_and_release_connection(self):
def side_effect(*args, **kwargs):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result(b'payload')
return fut
content = self.response.content = unittest.mock.Mock()
Expand All @@ -83,7 +84,7 @@ def side_effect(*args, **kwargs):

def test_read_and_release_connection_with_error(self):
content = self.response.content = unittest.mock.Mock()
content.read.return_value = asyncio.Future(loop=self.loop)
content.read.return_value = helpers.create_future(self.loop)
content.read.return_value.set_exception(ValueError)

self.assertRaises(
Expand All @@ -92,7 +93,7 @@ def test_read_and_release_connection_with_error(self):
self.assertTrue(self.response._closed)

def test_release(self):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result(b'')
content = self.response.content = unittest.mock.Mock()
content.readany.return_value = fut
Expand All @@ -103,7 +104,7 @@ def test_release(self):
def test_read_decode_deprecated(self):
self.response._content = b'data'
self.response.json = unittest.mock.Mock()
self.response.json.return_value = asyncio.Future(loop=self.loop)
self.response.json.return_value = helpers.create_future(self.loop)
self.response.json.return_value.set_result('json')

with self.assertWarns(DeprecationWarning):
Expand All @@ -113,7 +114,7 @@ def test_read_decode_deprecated(self):

def test_text(self):
def side_effect(*args, **kwargs):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result('{"тест": "пройден"}'.encode('cp1251'))
return fut
self.response.headers = {
Expand All @@ -127,7 +128,7 @@ def side_effect(*args, **kwargs):

def test_text_custom_encoding(self):
def side_effect(*args, **kwargs):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result('{"тест": "пройден"}'.encode('cp1251'))
return fut
self.response.headers = {
Expand All @@ -144,7 +145,7 @@ def side_effect(*args, **kwargs):

def test_text_detect_encoding(self):
def side_effect(*args, **kwargs):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result('{"тест": "пройден"}'.encode('cp1251'))
return fut
self.response.headers = {'CONTENT-TYPE': 'application/json'}
Expand All @@ -158,7 +159,7 @@ def side_effect(*args, **kwargs):

def test_text_after_read(self):
def side_effect(*args, **kwargs):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result('{"тест": "пройден"}'.encode('cp1251'))
return fut
self.response.headers = {
Expand All @@ -172,7 +173,7 @@ def side_effect(*args, **kwargs):

def test_json(self):
def side_effect(*args, **kwargs):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result('{"тест": "пройден"}'.encode('cp1251'))
return fut
self.response.headers = {
Expand Down Expand Up @@ -209,7 +210,7 @@ def test_json_no_content(self, m_log):

def test_json_override_encoding(self):
def side_effect(*args, **kwargs):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result('{"тест": "пройден"}'.encode('cp1251'))
return fut
self.response.headers = {
Expand All @@ -226,7 +227,7 @@ def side_effect(*args, **kwargs):

def test_json_detect_encoding(self):
def side_effect(*args, **kwargs):
fut = asyncio.Future(loop=self.loop)
fut = helpers.create_future(self.loop)
fut.set_result('{"тест": "пройден"}'.encode('cp1251'))
return fut
self.response.headers = {'CONTENT-TYPE': 'application/json'}
Expand Down
21 changes: 11 additions & 10 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import aiohttp
from aiohttp import web
from aiohttp import client
from aiohttp import helpers
from aiohttp.client import ClientResponse
from aiohttp.connector import Connection

Expand Down Expand Up @@ -273,7 +274,7 @@ class Req:
key = ('host', 80, False)
conn._conns[key] = [(tr, proto, self.loop.time())]
conn._create_connection = unittest.mock.Mock()
conn._create_connection.return_value = asyncio.Future(loop=self.loop)
conn._create_connection.return_value = helpers.create_future(self.loop)
conn._create_connection.return_value.set_result((tr, proto))

connection = self.loop.run_until_complete(conn.connect(Req()))
Expand All @@ -286,7 +287,7 @@ class Req:
def test_connect_timeout(self):
conn = aiohttp.BaseConnector(loop=self.loop)
conn._create_connection = unittest.mock.Mock()
conn._create_connection.return_value = asyncio.Future(loop=self.loop)
conn._create_connection.return_value = helpers.create_future(self.loop)
conn._create_connection.return_value.set_exception(
asyncio.TimeoutError())

Expand All @@ -297,7 +298,7 @@ def test_connect_timeout(self):
def test_connect_oserr(self):
conn = aiohttp.BaseConnector(loop=self.loop)
conn._create_connection = unittest.mock.Mock()
conn._create_connection.return_value = asyncio.Future(loop=self.loop)
conn._create_connection.return_value = helpers.create_future(self.loop)
err = OSError(1, 'permission error')
conn._create_connection.return_value.set_exception(err)

Expand Down Expand Up @@ -499,8 +500,8 @@ class Req:
key = ('host', 80, False)
conn._conns[key] = [(tr, proto, self.loop.time())]
conn._create_connection = unittest.mock.Mock()
conn._create_connection.return_value = asyncio.Future(
loop=self.loop)
conn._create_connection.return_value = helpers.create_future(
self.loop)
conn._create_connection.return_value.set_result((tr, proto))

connection1 = yield from conn.connect(Req())
Expand Down Expand Up @@ -547,8 +548,8 @@ class Req:
key = ('host', 80, False)
conn._conns[key] = [(tr, proto, self.loop.time())]
conn._create_connection = unittest.mock.Mock()
conn._create_connection.return_value = asyncio.Future(
loop=self.loop)
conn._create_connection.return_value = helpers.create_future(
self.loop)
conn._create_connection.return_value.set_result((tr, proto))

connection = yield from conn.connect(Req())
Expand All @@ -569,7 +570,7 @@ def check_with_exc(err):
conn = aiohttp.BaseConnector(limit=1, loop=self.loop)
conn._create_connection = unittest.mock.Mock()
conn._create_connection.return_value = \
asyncio.Future(loop=self.loop)
helpers.create_future(self.loop)
conn._create_connection.return_value.set_exception(err)

with self.assertRaises(Exception):
Expand Down Expand Up @@ -668,8 +669,8 @@ class Req:
key = ('host', 80, False)
conn._conns[key] = [(tr, proto, self.loop.time())]
conn._create_connection = unittest.mock.Mock()
conn._create_connection.return_value = asyncio.Future(
loop=self.loop)
conn._create_connection.return_value = helpers.create_future(
self.loop)
conn._create_connection.return_value.set_result((tr, proto))

connection = yield from conn.connect(Req())
Expand Down
23 changes: 23 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,26 @@ def test_requote_uri_properly_requotes():
# Ensure requoting doesn't break expectations.
quoted = 'http://example.com/fiz?buz=%25ppicture'
assert quoted == helpers.requote_uri(quoted)


def test_create_future_with_new_loop():
# We should use the new create_future() if it's available.
mock_loop = mock.Mock()
expected = 'hello'
mock_loop.create_future.return_value = expected
assert expected == helpers.create_future(mock_loop)


@mock.patch('asyncio.Future')
def test_create_future_with_old_loop(MockFuture):
# The old loop (without create_future()) should just have a Future object
# wrapped around it.
mock_loop = mock.Mock()
del mock_loop.create_future

expected = 'hello'
MockFuture.return_value = expected

future = helpers.create_future(mock_loop)
MockFuture.assert_called_with(loop=mock_loop)
assert expected == future
Loading

0 comments on commit 051d36f

Please sign in to comment.