Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
bug: Add connection timeouts for APNS
Browse files Browse the repository at this point in the history
Closes #1393
  • Loading branch information
jrconlin committed May 15, 2020
1 parent 00aaab4 commit e56414d
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 22 deletions.
71 changes: 58 additions & 13 deletions autopush/router/apns2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json
from collections import deque
from decimal import Decimal
import time

import threading

import hyper.tls
from hyper import HTTP20Connection
Expand Down Expand Up @@ -36,13 +39,21 @@ class APNSException(Exception):
pass


class APNSConnection:
def __init__(self, conn):
self._conn = conn
self._last_used = time.time()


class APNSClient(object):
def __init__(self, cert_file, key_file, topic,
alt=False, use_sandbox=False,
max_connections=APNS_MAX_CONNECTIONS,
logger=None, metrics=None,
load_connections=True,
max_retry=2):
max_retry=2,
conn_ttl=30,
reap_sleep=60):
"""Create the APNS client connector.
The cert_file and key_file can be derived from the exported `.p12`
Expand Down Expand Up @@ -86,19 +97,21 @@ def __init__(self, cert_file, key_file, topic,
self.log = logger
self.metrics = metrics
self.topic = topic
self._in_use_connections = 0
self._max_connections = max_connections
self._max_retry = max_retry
self._max_conn_ttl = conn_ttl
self.connections = deque(maxlen=max_connections)
if load_connections:
self.ssl_context = hyper.tls.init_context(cert=(cert_file,
key_file))
self.connections.extendleft((HTTP20Connection(
self.server,
self.port,
ssl_context=self.ssl_context,
force_proto='h2') for x in range(0, max_connections)))
if self.log:
self.log.debug("Starting APNS connection")
self._reap_sleep = reap_sleep
# this is probably wrong.
self._reaper = threading.Thread(target=self._reap, args=())
self._reaper.daemon = True
self._reaper.start()

def send(self, router_token, payload, apns_id,
priority=True, topic=None, exp=None):
Expand Down Expand Up @@ -154,7 +167,6 @@ def send(self, router_token, payload, apns_id,
status_code=502,
response_body="APNS could not process "
"your message {}".format(reason),
log_exception=False
)
break
except (HTTP20Error, IOError):
Expand All @@ -166,19 +178,52 @@ def send(self, router_token, payload, apns_id,
finally:
# Returning a closed connection to the pool is ok.
# hyper will reconnect on .request()
self._return_connection(connection)
if connection:
self._return_connection(connection)

def _get_connection(self):
try:
connection = self.connections.pop()
return connection
except IndexError:
self._in_use_connections += 1
if self.log:
self.log.debug("Got New APNS connection")
if self._in_use_connections > self._max_connections:
raise RouterException(
"Too many APNS requests, increase pool from {}".format(
self._max_connections
),
status_code=503,
response_body="APNS busy, please retry")
try:
connection = self.connections.pop()
if self.log:
self.log.debug("Got existing APNS connection")
return connection._conn
except IndexError:
return HTTP20Connection(
self.server,
self.port,
ssl_context=self.ssl_context,
force_proto='h2')

def _return_connection(self, connection):
self.connections.appendleft(connection)
self._in_use_connections -= 1
if self.log:
self.log.debug("Done with APNS connection")
conn = APNSConnection(connection)
self.connections.appendleft(conn)

def _reap(self):
if self.log:
self.log.debug("Reaping APNS connections")
for connection in self.connections:
geezers = []
if connection._last_used < (time.time() - self._max_conn_ttl):
if self.log:
self.log.debug("Found old connection")
geezers.append(connection)
try:
for geezer in geezers:
connection._conn.close()
self.connections.remove(geezer)
except ValueError:
pass
time.sleep(self._reap_sleep)
5 changes: 4 additions & 1 deletion autopush/router/apnsrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,16 @@ def _route(self, notification, router_data):
apns_client.send(router_token=router_token, payload=payload,
apns_id=apns_id)
success = True
except RouterException:
# Not sure if this is happening, but
raise
except ConnectionError:
self.metrics.increment("notification.bridge.connection.error",
tags=make_tags(
self._base_tags,
application=rel_channel,
reason="connection_error"))
except (HTTP20Error, socket.error):
except (HTTP20Error, IOError, socket.error):
self.metrics.increment("notification.bridge.connection.error",
tags=make_tags(self._base_tags,
application=rel_channel,
Expand Down
34 changes: 26 additions & 8 deletions autopush/tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from autopush.metrics import SinkMetrics
from autopush.router import (APNSRouter, FCMRouter, FCMv1Router, GCMRouter,
WebPushRouter, fcmv1client, gcmclient)
from autopush.router.apns2 import APNSConnection
from autopush.router.interface import IRouter, RouterResponse
from autopush.tests import MockAssist
from autopush.tests.support import test_db
Expand Down Expand Up @@ -101,7 +102,7 @@ def setUp(self, mt, mc):
except IndexError: # pragma nocover
pass
self.router.apns['firefox'].connections.append(
self.mock_connection
APNSConnection(self.mock_connection)
)
self.router.apns['firefox'].log = Mock(spec=Logger)
self.headers = {"content-encoding": "aesgcm",
Expand Down Expand Up @@ -154,8 +155,9 @@ def test_connection_error(self):
def raiser(*args, **kwargs):
raise ConnectionError("oops")

self.router.apns['firefox'].connections[1].request = Mock(
side_effect=raiser)
self.mock_connection.request = Mock(side_effect=raiser)
self.router.apns['firefox'].connections.append(
APNSConnection(self.mock_connection))

with pytest.raises(RouterException) as ex:
yield self.router.route_notification(self.notif, self.router_data)
Expand All @@ -173,8 +175,9 @@ def raiser(*args, **kwargs):
error.errno = socket.errno.EPIPE
raise error

self.router.apns['firefox'].connections[1].request = Mock(
side_effect=raiser)
self.mock_connection.request = Mock(side_effect=raiser)
self.router.apns['firefox'].connections.append(
APNSConnection(self.mock_connection))

with pytest.raises(RouterException) as ex:
yield self.router.route_notification(self.notif, self.router_data)
Expand Down Expand Up @@ -267,7 +270,10 @@ def test_fail_send(self):
def throw(*args, **kwargs):
raise HTTP20Error("oops")

self.router.apns['firefox'].connections[0].request.side_effect = throw
self.mock_connection.request = Mock(side_effect=throw)
self.router.apns['firefox'].connections.append(
APNSConnection(self.mock_connection))

with pytest.raises(RouterException) as ex:
yield self.router.route_notification(self.notif, self.router_data)
assert isinstance(ex.value, RouterException)
Expand All @@ -288,7 +294,10 @@ def throw(*args, **kwargs):
"[SSL: BAD_WRITE_RETRY] bad write retry"
)

self.router.apns['firefox'].connections[0].request.side_effect = throw
self.mock_connection.request = Mock(side_effect=throw)
self.router.apns['firefox'].connections.append(
APNSConnection(self.mock_connection))

with pytest.raises(RouterException) as ex:
yield self.router.route_notification(self.notif, self.router_data)
assert isinstance(ex.value, RouterException)
Expand Down Expand Up @@ -338,11 +347,20 @@ def check_results(result):
assert result.status_code == 201
assert result.logged_status == 200
assert "TTL" in result.headers
assert self.mock_connection.called
assert self.mock_connection.request.called

d.addCallback(check_results)
return d

def test_reaper(self):
self.router.apns['firefox']._max_conn_ttl = 1
self.router.apns['firefox']._reap_sleep = 0
c = self.router.apns['firefox']._get_connection()
self.router.apns['firefox']._return_connection(c)
time.sleep(1)
self.router.apns['firefox']._reap()
assert len(self.router.apns['firefox'].connections) == 0


class GCMRouterTestCase(unittest.TestCase):

Expand Down

0 comments on commit e56414d

Please sign in to comment.