From 2f277f93a69e62d016cb445069325ce0692b9963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9A=D0=BE=D1=80=D0=B5=D0=BD=D0=B1=D0=B5=D1=80=D0=B3=20?= =?UTF-8?q?=D0=9C=D0=B0=D1=80=D0=BA?= Date: Tue, 26 Dec 2017 14:07:02 +0500 Subject: [PATCH] Fix DNSCache race-condition (#2620) --- CHANGES/2620.bugfix | 1 + aiohttp/connector.py | 9 +++++---- tests/test_connector.py | 28 ++++++++++++++++++++-------- 3 files changed, 26 insertions(+), 12 deletions(-) create mode 100644 CHANGES/2620.bugfix diff --git a/CHANGES/2620.bugfix b/CHANGES/2620.bugfix new file mode 100644 index 00000000000..e977f40f253 --- /dev/null +++ b/CHANGES/2620.bugfix @@ -0,0 +1 @@ +Fixed race-condition for iterating addresses from the DNSCache. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index f51641c4046..e2dd92cc293 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -578,10 +578,11 @@ def clear(self): self._timestamps.clear() def next_addrs(self, host): - # Return an iterator that will get at maximum as many addrs - # there are for the specific host starting from the last - # not itereated addr. - return islice(self._addrs_rr[host], len(self._addrs[host])) + loop = self._addrs_rr[host] + addrs = list(islice(loop, len(self._addrs[host]))) + # Consume one more element to shift internal state of `cycle` + next(loop) + return addrs def expired(self, host): if self._ttl is None: diff --git a/tests/test_connector.py b/tests/test_connector.py index 9c39bd46490..05eaf112771 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1959,15 +1959,27 @@ async def test_expired_ttl(self, loop): assert dns_cache_table.expired('localhost') def test_next_addrs(self, dns_cache_table): - dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2']) + dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2', '127.0.0.3']) - # max elements returned are the full list of addrs - addrs = list(dns_cache_table.next_addrs('foo')) - assert addrs == ['127.0.0.1', '127.0.0.2'] - - # different calls to next_addrs return the hosts using + # Each calls to next_addrs return the hosts using # a round robin strategy. addrs = dns_cache_table.next_addrs('foo') - assert next(addrs) == '127.0.0.1' + assert addrs == ['127.0.0.1', '127.0.0.2', '127.0.0.3'] + + addrs = dns_cache_table.next_addrs('foo') + assert addrs == ['127.0.0.2', '127.0.0.3', '127.0.0.1'] + + addrs = dns_cache_table.next_addrs('foo') + assert addrs == ['127.0.0.3', '127.0.0.1', '127.0.0.2'] + + addrs = dns_cache_table.next_addrs('foo') + assert addrs == ['127.0.0.1', '127.0.0.2', '127.0.0.3'] + + def test_next_addrs_single(self, dns_cache_table): + dns_cache_table.add('foo', ['127.0.0.1']) + + addrs = dns_cache_table.next_addrs('foo') + assert addrs == ['127.0.0.1'] + addrs = dns_cache_table.next_addrs('foo') - assert next(addrs) == '127.0.0.2' + assert addrs == ['127.0.0.1']