diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index d7d20fb7f6e..a9cd6ceb3dc 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -306,8 +306,9 @@ def make_mocked_request(method, path, headers=None, *, :param path: str, The URL including *PATH INFO* without the host or scheme :type path: str - :param headers: CIMultiDict with all request headers - :type headers: multidict.CIMultiDict + :param headers: mapping containing the headers. Can be anything accepted + by the multidict.CIMultiDict constructor. + :type headers: dict, multidict.CIMultiDict, list of pairs :param version: namedtuple with encoded HTTP version :type version: aiohttp.protocol.HttpVersion @@ -344,7 +345,7 @@ def make_mocked_request(method, path, headers=None, *, closing = True if headers: - hdrs = headers + hdrs = CIMultiDict(headers) raw_hdrs = [ (k.encode('utf-8'), v.encode('utf-8')) for k, v in headers.items()] else: diff --git a/docs/testing.rst b/docs/testing.rst index 9cf452bd970..c54c2697f97 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -205,14 +205,13 @@ conditions that hard to reproduce on real server:: from aiohttp import web from aiohttp.test_utils import make_mocked_request - from multidict import CIMultiDict def handler(request): assert request.headers.get('token') == 'x' return web.Response(body=b'data') def test_handler(): - req = make_mocked_request('GET', '/', headers=CIMultiDict({'token': 'x'})) + req = make_mocked_request('GET', '/', headers={'token': 'x'}) resp = handler(req) assert resp.body == b'data' diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 5c8a897e480..b76e0a9f1d1 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -1,12 +1,13 @@ import asyncio import pytest +from multidict import CIMultiDict, CIMultiDictProxy import aiohttp -from aiohttp import web +from aiohttp import web, web_reqrep from aiohttp.test_utils import (AioHTTPTestCase, TestClient, loop_context, - setup_test_loop, teardown_test_loop, - unittest_run_loop) + make_mocked_request, setup_test_loop, + teardown_test_loop, unittest_run_loop) def _create_example_app(loop): @@ -180,3 +181,13 @@ def test_test_client_methods(method, loop, test_client): def test_test_client_head(loop, test_client): resp = yield from test_client.head("/") assert resp.status == 200 + + +@pytest.mark.parametrize( + "headers", [{'token': 'x'}, CIMultiDict({'token': 'x'}), {}]) +def test_make_mocked_request(headers): + req = make_mocked_request('GET', '/', headers=headers) + assert req.method == "GET" + assert req.path == "/" + assert isinstance(req, web_reqrep.Request) + assert isinstance(req.headers, CIMultiDictProxy)