From 7956a1a0d7fea9cf50fb8cbd7dcd2621f31877fd Mon Sep 17 00:00:00 2001 From: pfreixes Date: Thu, 21 Jul 2016 12:29:37 +0200 Subject: [PATCH] DRAFT, New TestClientApp implementation A test client imlementation bound to an App to be used skipping the network layer. Rather than use the `make_mocked_request`, it tries to keep the client instance as much as we can tied to the app with its characteristics, perhaps routing. This new aproximation is formulated to be used for testing keeping the interface quite close to that one used by the AioHttp client, tryig to make the life easier to the developer. ```python client = app.test_client() response = yield from client.get("/") assert (yield from response.text()) == "Hello, world" ``` --- aiohttp/test_utils.py | 167 +++++++++++++++++++++++++++++++++++++++ aiohttp/web.py | 4 + tests/test_test_utils.py | 8 ++ 3 files changed, 179 insertions(+) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index bb9e784eef4..abbef605cd2 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -18,6 +18,7 @@ import traceback import urllib.parse import unittest +import chardet from unittest import mock import asyncio @@ -642,3 +643,169 @@ def make_mocked_request(method, path, headers=CIMultiDict(), *, assert req.transport is transport return req + + +class _TestClientTransport: + + def __init__(self): + self._data = b"" + + @property + def data(self): + return self._data + + def write(self, data): + self._data += data + + def drain(self): + return () + + def get_extra_info(self, *args, **kwargs): + return [""] + + @staticmethod + def _mocked_method(*args, **kwargs): + return mock.Mock() + + set_tcp_nodelay = _mocked_method + set_tcp_cork = _mocked_method + + +class _TestClientBuffer: + + def __init__(self, data): + self._data = data + + def readuntil(self, stop, limit=None): + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + + pos = self._data.find(stop) + if pos >= 0: + end = pos + stop_len + size = end + if limit is not None and size > limit: + raise errors.LineLimitExceededParserError( + 'Line is too long.', limit) + + data = self._data[:size] + return data + else: + if limit is not None and len(self._data) > limit: + raise errors.LineLimitExceededParserError( + 'Line is too long.', limit) + + yield + + +class _TestClientResponse: + + def __init__(self, data, loop): + self._data = data + self._loop = loop + + def feed_data(self, raw_response_message, len_raw_data): + self._raw_response_message = raw_response_message + self._len_raw_data = len_raw_data + + def feed_eof(self): + pass + + @property + def version(self): + return self._raw_response_message.version + + @property + def code(self): + return self._raw_response_message.code + + @property + def reason(self): + return self._raw_response_message.reason + + @property + def headers(self): + return self._raw_response_message.headers + + @property + def content(self): + return self._data[self._len_raw_data:] + + def text(self, encoding=None): + if encoding is None: + encoding = self._get_encoding() + f = asyncio.Future(loop=self._loop) + f.set_result(self.content.decode(encoding)) + return f + + def json(self, *, encoding=None, loads=json.loads): + ctype = self.headers.get(hdrs.CONTENT_TYPE, '').lower() + if 'json' not in ctype: + client_logger.warning( + 'Attempt to decode JSON with unexpected mimetype: %s', ctype) + + stripped = self.content.strip() + if not stripped: + return None + + if encoding is None: + encoding = self._get_encoding() + + f = asyncio.Future(loop=self._loop) + f.set_result(loads(stripped.decode(encoding))) + return f + + def _get_encoding(self): + ctype = self.headers.get(hdrs.CONTENT_TYPE, '').lower() + mtype, stype, _, params = helpers.parse_mimetype(ctype) + + encoding = params.get('charset') + if not encoding: + encoding = chardet.detect(self.content)['encoding'] + if not encoding: + encoding = 'utf-8' + + return encoding + + +class TestClientApp: + """ + A test client imlementation bound to an App to be used skipping the + network layer, formulated to be used for testing such as the AioHttp Client + + >>> client = app.test_client() + >>> response = yield from client.get("/") + >>> assert (yield from response.text()) == "Hello, world" + """ + + def __init__(self, app, handler): + self._app = app + self._handler = handler + + @asyncio.coroutine + def get(self, path, version=None, body=None, headers=None, closing=False): + if headers: + raw_hdrs = [ + (k.encode('utf-8'), v.encode('utf-8')) for k, v in hdrs.items() + ] + else: + headers = CIMultiDict() + raw_hdrs = None + + message = RawRequestMessage('GET', path, version or HttpVersion(1, 1), + headers, raw_hdrs, closing, False) + handler = self._handler(self, self._app, self._app.router, + loop=self._app.loop) + transport = _TestClientTransport() + handler.transport = transport + handler.writer = transport + yield from handler.handle_request(message, mock.Mock()) + + response = _TestClientResponse(transport.data, self._app.loop) + response_parser = aiohttp.HttpResponseParser() + yield from response_parser(response, _TestClientBuffer(transport.data)) + return response diff --git a/aiohttp/web.py b/aiohttp/web.py index 753be841371..8358f1ae05d 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -284,6 +284,10 @@ def register_on_finish(self, func, *args, **kwargs): def copy(self): raise NotImplementedError + def test_client(self, handler=RequestHandler): + from .test_utils import TestClientApp + return TestClientApp(self, handler) + def __call__(self): """gunicorn compatibility""" return self diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 002e03950c5..45c78596312 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -165,3 +165,11 @@ def test_test_client_methods(method, loop, test_client): def test_test_client_head(loop, test_client): resp = yield from test_client.head("/") assert resp.status == 200 + + +@pytest.mark.run_loop +@asyncio.coroutine +def test_test_client_app(loop, app): + client = app.test_client() + response = yield from client.get("/") + assert (yield from response.text()) == "Hello, world"