Skip to content

Commit

Permalink
Fix DNSCache race-condition (#2620)
Browse files Browse the repository at this point in the history
  • Loading branch information
socketpair committed Dec 25, 2017
1 parent 2c169cb commit 7d4dcf4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
13 changes: 4 additions & 9 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from contextlib import suppress
from hashlib import md5, sha1, sha256
from http.cookies import SimpleCookie
from itertools import cycle, islice
from random import shuffle
from time import monotonic
from types import MappingProxyType

Expand Down Expand Up @@ -547,7 +547,6 @@ class _DNSCacheTable:

def __init__(self, ttl=None):
self._addrs = {}
self._addrs_rr = {}
self._timestamps = {}
self._ttl = ttl

Expand All @@ -560,28 +559,24 @@ def addrs(self):

def add(self, host, addrs):
self._addrs[host] = addrs
self._addrs_rr[host] = cycle(addrs)

if self._ttl:
self._timestamps[host] = monotonic()

def remove(self, host):
self._addrs.pop(host, None)
self._addrs_rr.pop(host, None)

if self._ttl:
self._timestamps.pop(host, None)

def clear(self):
self._addrs.clear()
self._addrs_rr.clear()
self._timestamps.clear()

def next_addrs(self, host):
# Return an iterator that will get at maximum as many addrs
# there are for the specific host starting from the last
# not itereated addr.
return islice(self._addrs_rr[host], len(self._addrs[host]))
addrs = self._addrs[host].copy()
shuffle(addrs)
return addrs

def expired(self, host):
if self._ttl is None:
Expand Down
14 changes: 6 additions & 8 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hashlib
import os.path
import platform
import random
import shutil
import socket
import ssl
Expand Down Expand Up @@ -1961,13 +1962,10 @@ async def test_expired_ttl(self, loop):
def test_next_addrs(self, dns_cache_table):
dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2'])

# max elements returned are the full list of addrs
addrs = list(dns_cache_table.next_addrs('foo'))
assert addrs == ['127.0.0.1', '127.0.0.2']

# different calls to next_addrs return the hosts using
# a round robin strategy.
random.seed(1)
addrs = dns_cache_table.next_addrs('foo')
assert next(addrs) == '127.0.0.1'
assert addrs == ['127.0.0.2', '127.0.0.1']

random.seed(5)
addrs = dns_cache_table.next_addrs('foo')
assert next(addrs) == '127.0.0.2'
assert addrs == ['127.0.0.1', '127.0.0.2']

0 comments on commit 7d4dcf4

Please sign in to comment.