Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nsqd Authentication Support #72

Merged
merged 1 commit into from
Jun 7, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions nsq/async.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class AsyncConn(EventedMixin):
* ``error``
* ``identify``
* ``identify_response``
* ``auth``
* ``auth_response``
* ``heartbeat``
* ``ready``
* ``message``
Expand Down Expand Up @@ -83,11 +85,15 @@ class AsyncConn(EventedMixin):

:param user_agent: a string identifying the agent for this client in the spirit of
HTTP (default: ``<client_library_name>/<version>``) (requires nsqd 0.2.25+)

:param authentication_secret: a string passed to nsqauthd when usting nsqd authentication
(requires nsqd 1.0+)
"""
def __init__(self, host, port, timeout=1.0, heartbeat_interval=30, requeue_delay=90,
tls_v1=False, tls_options=None, snappy=False, user_agent=None,
output_buffer_size=16 * 1024, output_buffer_timeout=250, sample_rate=0,
io_loop=None):
io_loop=None,
authentication_secret=None):
assert isinstance(host, (str, unicode))
assert isinstance(port, int)
assert isinstance(timeout, float)
Expand All @@ -97,6 +103,7 @@ def __init__(self, host, port, timeout=1.0, heartbeat_interval=30, requeue_delay
assert isinstance(output_buffer_size, int) and output_buffer_size >= 0
assert isinstance(output_buffer_timeout, int) and output_buffer_timeout >= 0
assert isinstance(sample_rate, int) and sample_rate >= 0 and sample_rate < 100
assert isinstance(authentication_secret, (str, unicode, None.__class__))
assert tls_v1 and ssl or not tls_v1, \
'tls_v1 requires Python 2.6+ or Python 2.5 w/ pip install ssl'

Expand Down Expand Up @@ -130,7 +137,9 @@ def __init__(self, host, port, timeout=1.0, heartbeat_interval=30, requeue_delay

if self.user_agent is None:
self.user_agent = 'pynsq/%s' % __version__


self._authentication_required = False # tracking server auth state
self.authentication_secret = authentication_secret
super(AsyncConn, self).__init__()

@property
Expand Down Expand Up @@ -270,6 +279,7 @@ def _on_identify_response(self, data, **kwargs):
self.off('response', self._on_identify_response)

if data == 'OK':
logging.warning('nsqd version does not support feature netgotiation')
return self.trigger('ready', conn=self)

try:
Expand All @@ -287,6 +297,9 @@ def _on_identify_response(self, data, **kwargs):
self._features_to_enable.append('tls_v1')
if self.snappy and data.get('snappy'):
self._features_to_enable.append('snappy')

if data.get('auth_required'):
self._authentication_required = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two lines can be simplified to self._authentication_required = data.get('auth_required', False)


self.on('response', self._on_response_continue)
self._on_response_continue(conn=self, data=None)
Expand All @@ -298,11 +311,34 @@ def _on_response_continue(self, data, **kwargs):
self.upgrade_to_tls(self.tls_options)
elif feature == 'snappy':
self.upgrade_to_snappy()
# the server will 'OK' after these conneciton upgrades triggering another response
return

self.off('response', self._on_response_continue)
if self.authentication_secret and self._authentication_required:
self.on('response', self._on_auth_response)
self.trigger('auth', conn=self, data=self.authentication_secret)
try:
self.send(nsq.auth(self.authentication_secret))
except Exception, e:
self.close()
self.trigger('error', conn=self, error=nsq.SendError('Error sending AUTH', e))
return
self.trigger('ready', conn=self)

def _on_auth_response(self, data, **kwargs):
try:
data = json.loads(data)
except ValueError:
self.close()
err = 'failed to parse AUTH response JSON from nsqd - %r' % data
self.trigger('error', conn=self, error=nsq.IntegrityError(err))
return

self.off('response', self._on_auth_response)
self.trigger('auth_response', conn=self, data=data)
return self.trigger('ready', conn=self)

def _on_data(self, data, **kwargs):
self.last_recv_timestamp = time.time()
frame, data = nsq.unpack_response(data)
Expand Down
15 changes: 15 additions & 0 deletions nsq/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,22 @@ def _on_connection_identify_response(self, conn, data, **kwargs):
logging.warning('[%s:%s] snappy requested but disabled, could not negotiate feature',
conn.id, self.name)

def _on_connection_auth(self, conn, data, **kwargs):
logging.info('[%s:%s] AUTH sent' % (conn.id, self.name))

def _on_connection_auth_response(self, conn, data, **kwargs):
metadata = []
if data.get('identity'):
metadata.append("Identity: %r" % data['identity'])
if data.get('permission_count'):
metadata.append("Permissions: %d" % data['permission_count'])
if data.get('identity_url'):
metadata.append(data['identity_url'])
logging.info('[%s:%s] AUTH accepted %s' % (conn.id, self.name, ' '.join(metadata)))

def _on_connection_error(self, conn, error, **kwargs):
if kwargs:
logging.error('[%s:%s ERROR: %r]', conn.id, self.name, kwargs)
logging.error('[%s:%s] ERROR: %r', conn.id, self.name, error)

def _check_last_recv_timestamps(self):
Expand Down
2 changes: 2 additions & 0 deletions nsq/nsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def subscribe(topic, channel):
def identify(data):
return _command('IDENTIFY', json.dumps(data))

def auth(data):
return _command('AUTH', data)

def ready(count):
assert isinstance(count, int), 'ready count must be an integer'
Expand Down
2 changes: 2 additions & 0 deletions nsq/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ def connect_to_nsqd(self, host, port):
conn = async.AsyncConn(host, port, **self.conn_kwargs)
conn.on('identify', self._on_connection_identify)
conn.on('identify_response', self._on_connection_identify_response)
conn.on('auth', self._on_connection_auth)
conn.on('auth_response', self._on_connection_auth_response)
conn.on('error', self._on_connection_error)
conn.on('close', self._on_connection_close)
conn.on('ready', self._on_connection_ready)
Expand Down
20 changes: 15 additions & 5 deletions nsq/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, handlers, **settings):

:param \*\*kwargs: passed to :class:`nsq.AsyncConn` initialization
"""
def __init__(self, nsqd_tcp_addresses, name=None, **kwargs):
def __init__(self, nsqd_tcp_addresses, reconnect_interval=15.0, name=None, **kwargs):
super(Writer, self).__init__(**kwargs)

if not isinstance(nsqd_tcp_addresses, (list, set, tuple)):
Expand All @@ -88,6 +88,8 @@ def __init__(self, nsqd_tcp_addresses, name=None, **kwargs):
self.nsqd_tcp_addresses = nsqd_tcp_addresses
self.conns = {}
self.conn_kwargs = kwargs
assert isinstance(reconnect_interval, (int, float))
self.reconnect_interval = reconnect_interval

self.io_loop.add_callback(self._run)

Expand Down Expand Up @@ -123,7 +125,13 @@ def _pub(self, command, topic, msg, callback):
logging.exception('[%s] failed to send %s' % (conn.id, command))
conn.close()

def _on_connection_response(self, conn, data, **kwargs):
def _on_connection_error(self, conn, error, **kwargs):
super(Writer, self)._on_connection_error(conn, error, **kwargs)
while conn.callback_queue:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

callback = conn.callback_queue.pop(0)
callback(conn, error)

def _on_connection_response(self, conn, data=None, **kwargs):
if conn.callback_queue:
callback = conn.callback_queue.pop(0)
callback(conn, data)
Expand All @@ -140,7 +148,9 @@ def connect_to_nsqd(self, host, port):
conn = async.AsyncConn(host, port, **self.conn_kwargs)
conn.on('identify', self._on_connection_identify)
conn.on('identify_response', self._on_connection_identify_response)
conn.on('error', self._on_connection_response)
conn.on('auth', self._on_connection_auth)
conn.on('auth_response', self._on_connection_auth_response)
conn.on('error', self._on_connection_error)
conn.on('response', self._on_connection_response)
conn.on('close', self._on_connection_close)
conn.on('ready', self._on_connection_ready)
Expand Down Expand Up @@ -173,10 +183,10 @@ def _on_connection_close(self, conn, **kwargs):
logging.exception('[%s] uncaught exception in callback', conn.id)

logging.warning('[%s] connection closed', conn.id)
logging.info('[%s] attempting to reconnect in 15s', conn.id)
logging.info('[%s] attempting to reconnect in %0.2fs', conn.id, self.reconnect_interval)
reconnect_callback = functools.partial(self.connect_to_nsqd,
host=conn.host, port=conn.port)
self.io_loop.add_timeout(time.time() + 15, reconnect_callback)
self.io_loop.add_timeout(time.time() + self.reconnect_interval, reconnect_callback)

def _finish_pub(self, conn, data, command, topic, msg):
if isinstance(data, nsq.Error):
Expand Down