diff --git a/CHANGES/2620.bugfix b/CHANGES/2620.bugfix new file mode 100644 index 00000000000..e977f40f253 --- /dev/null +++ b/CHANGES/2620.bugfix @@ -0,0 +1 @@ +Fixed race-condition for iterating addresses from the DNSCache. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index f51641c4046..f560583d2c5 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -578,10 +578,9 @@ def clear(self): self._timestamps.clear() def next_addrs(self, host): - # Return an iterator that will get at maximum as many addrs - # there are for the specific host starting from the last - # not itereated addr. - return islice(self._addrs_rr[host], len(self._addrs[host])) + loop = self._addrs_rr[host] + next(loop) + return list(islice(loop, len(self._addrs[host]))) def expired(self, host): if self._ttl is None: diff --git a/tests/test_connector.py b/tests/test_connector.py index 9c39bd46490..05eaf112771 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1959,15 +1959,27 @@ async def test_expired_ttl(self, loop): assert dns_cache_table.expired('localhost') def test_next_addrs(self, dns_cache_table): - dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2']) + dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2', '127.0.0.3']) - # max elements returned are the full list of addrs - addrs = list(dns_cache_table.next_addrs('foo')) - assert addrs == ['127.0.0.1', '127.0.0.2'] - - # different calls to next_addrs return the hosts using + # Each calls to next_addrs return the hosts using # a round robin strategy. addrs = dns_cache_table.next_addrs('foo') - assert next(addrs) == '127.0.0.1' + assert addrs == ['127.0.0.1', '127.0.0.2', '127.0.0.3'] + + addrs = dns_cache_table.next_addrs('foo') + assert addrs == ['127.0.0.2', '127.0.0.3', '127.0.0.1'] + + addrs = dns_cache_table.next_addrs('foo') + assert addrs == ['127.0.0.3', '127.0.0.1', '127.0.0.2'] + + addrs = dns_cache_table.next_addrs('foo') + assert addrs == ['127.0.0.1', '127.0.0.2', '127.0.0.3'] + + def test_next_addrs_single(self, dns_cache_table): + dns_cache_table.add('foo', ['127.0.0.1']) + + addrs = dns_cache_table.next_addrs('foo') + assert addrs == ['127.0.0.1'] + addrs = dns_cache_table.next_addrs('foo') - assert next(addrs) == '127.0.0.2' + assert addrs == ['127.0.0.1']