Skip to content

Commit

Permalink
Bracket IPv6 addresses in the HOST header (aio-libs#3304)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel Polyakov committed Sep 30, 2018
1 parent 508adbb commit 85ef52d
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 12 deletions.
2 changes: 2 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def update_auto_headers(self, skip_auto_headers):
# add host
if hdrs.HOST not in used_headers:
netloc = self.url.raw_host
if helpers.is_ipv6_address(netloc):
netloc = '[{}]'.format(netloc)
if not self.url.is_default_port():
netloc += ':' + str(self.url.port)
self.headers[hdrs.HOST] = netloc
Expand Down
27 changes: 16 additions & 11 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from pathlib import Path
from types import TracebackType
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator,
List, Mapping, Optional, Tuple, Type, TypeVar, Union, cast)
List, Mapping, Optional, Pattern, Tuple, Type, TypeVar,
Union, cast)
from urllib.parse import quote
from urllib.request import getproxies

Expand Down Expand Up @@ -604,25 +605,29 @@ def __set__(self, inst: Any, value: Any) -> None:
_ipv6_regexb = re.compile(_ipv6_pattern.encode('ascii'), flags=re.IGNORECASE)


def is_ip_address(
host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
def _is_ip_address(
regex: Pattern, regexb: Pattern,
host: Optional[Union[str, bytes, bytearray, memoryview]])-> bool:
if host is None:
return False
if isinstance(host, str):
if _ipv4_regex.match(host) or _ipv6_regex.match(host):
return True
else:
return False
return bool(regex.match(host))
elif isinstance(host, (bytes, bytearray, memoryview)):
if _ipv4_regexb.match(host) or _ipv6_regexb.match(host): # type: ignore # noqa
return True
else:
return False
return bool(regexb.match(host))
else:
raise TypeError("{} [{}] is not a str or bytes"
.format(host, type(host)))


is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(
host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
return is_ipv4_address(host) or is_ipv6_address(host)


_cached_current_datetime = None
_cached_formatted_datetime = None

Expand Down
20 changes: 20 additions & 0 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,26 @@ def test_host_header_explicit_host_with_port(make_request) -> None:
assert req.headers['HOST'] == 'example.com:99'


def test_host_header_ipv4(make_request) -> None:
req = make_request('get', 'http://127.0.0.2')
assert req.headers['HOST'] == '127.0.0.2'


def test_host_header_ipv6(make_request) -> None:
req = make_request('get', 'http://[::2]')
assert req.headers['HOST'] == '[::2]'


def test_host_header_ipv4_with_port(make_request) -> None:
req = make_request('get', 'http://127.0.0.2:99')
assert req.headers['HOST'] == '127.0.0.2:99'


def test_host_header_ipv6_with_port(make_request) -> None:
req = make_request('get', 'http://[::2]:99')
assert req.headers['HOST'] == '[::2]:99'


def test_default_loop(loop) -> None:
asyncio.set_event_loop(loop)
req = ClientRequest('get', URL('http://python.org/'))
Expand Down
13 changes: 12 additions & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,20 @@ def test_is_ip_address_bytes() -> None:
assert not helpers.is_ip_address(b"1200::AB00:1234::2552:7777:1313")


def test_ip_addresses() -> None:
def test_ipv4_addresses() -> None:
ip_addresses = [
'0.0.0.0',
'127.0.0.1',
'255.255.255.255',
]
for address in ip_addresses:
assert helpers.is_ipv4_address(address)
assert not helpers.is_ipv6_address(address)
assert helpers.is_ip_address(address)


def test_ipv6_addresses() -> None:
ip_addresses = [
'0:0:0:0:0:0:0:0',
'FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF',
'00AB:0002:3008:8CFD:00AB:0002:3008:8CFD',
Expand All @@ -405,6 +414,8 @@ def test_ip_addresses() -> None:
'1::1',
]
for address in ip_addresses:
assert not helpers.is_ipv4_address(address)
assert helpers.is_ipv6_address(address)
assert helpers.is_ip_address(address)


Expand Down

0 comments on commit 85ef52d

Please sign in to comment.