From 063993ad0567478357daa0c1b036cd3a41143a36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9A=D0=BE=D1=80=D0=B5=D0=BD=D0=B1=D0=B5=D1=80=D0=B3=20?= =?UTF-8?q?=D0=9C=D0=B0=D1=80=D0=BA?= Date: Mon, 25 Dec 2017 13:11:38 +0500 Subject: [PATCH] Fix DNSCache race-condition (#2620) --- CHANGES/2620.bugfix | 1 + aiohttp/connector.py | 13 ++++--------- tests/test_connector.py | 14 ++++++-------- 3 files changed, 11 insertions(+), 17 deletions(-) create mode 100644 CHANGES/2620.bugfix diff --git a/CHANGES/2620.bugfix b/CHANGES/2620.bugfix new file mode 100644 index 00000000000..2cdf9452afc --- /dev/null +++ b/CHANGES/2620.bugfix @@ -0,0 +1 @@ +DNSCache now returns random shuffle of IP addresses instead of round-robin generator. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index f51641c4046..4205affaeae 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -7,7 +7,7 @@ from contextlib import suppress from hashlib import md5, sha1, sha256 from http.cookies import SimpleCookie -from itertools import cycle, islice +from random import shuffle from time import monotonic from types import MappingProxyType @@ -547,7 +547,6 @@ class _DNSCacheTable: def __init__(self, ttl=None): self._addrs = {} - self._addrs_rr = {} self._timestamps = {} self._ttl = ttl @@ -560,28 +559,24 @@ def addrs(self): def add(self, host, addrs): self._addrs[host] = addrs - self._addrs_rr[host] = cycle(addrs) if self._ttl: self._timestamps[host] = monotonic() def remove(self, host): self._addrs.pop(host, None) - self._addrs_rr.pop(host, None) if self._ttl: self._timestamps.pop(host, None) def clear(self): self._addrs.clear() - self._addrs_rr.clear() self._timestamps.clear() def next_addrs(self, host): - # Return an iterator that will get at maximum as many addrs - # there are for the specific host starting from the last - # not itereated addr. - return islice(self._addrs_rr[host], len(self._addrs[host])) + addrs = self._addrs[host].copy() + shuffle(addrs) + return addrs def expired(self, host): if self._ttl is None: diff --git a/tests/test_connector.py b/tests/test_connector.py index 9c39bd46490..c1146024498 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -5,6 +5,7 @@ import hashlib import os.path import platform +import random import shutil import socket import ssl @@ -1961,13 +1962,10 @@ async def test_expired_ttl(self, loop): def test_next_addrs(self, dns_cache_table): dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2']) - # max elements returned are the full list of addrs - addrs = list(dns_cache_table.next_addrs('foo')) - assert addrs == ['127.0.0.1', '127.0.0.2'] - - # different calls to next_addrs return the hosts using - # a round robin strategy. + random.seed(1) addrs = dns_cache_table.next_addrs('foo') - assert next(addrs) == '127.0.0.1' + assert addrs == ['127.0.0.2', '127.0.0.1'] + + random.seed(5) addrs = dns_cache_table.next_addrs('foo') - assert next(addrs) == '127.0.0.2' + assert addrs == ['127.0.0.1', '127.0.0.2']