Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auth support for lookupd connections #77

Merged
merged 1 commit into from
Jun 14, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions nsq/async.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,13 @@ class AsyncConn(EventedMixin):
:param user_agent: a string identifying the agent for this client in the spirit of
HTTP (default: ``<client_library_name>/<version>``) (requires nsqd 0.2.25+)

:param authentication_secret: a string passed to nsqauthd when usting nsqd authentication
(requires nsqd 1.0+)
:param auth_secret: a string passed when using nsq auth (requires nsqd 1.0+)
"""
def __init__(self, host, port, timeout=1.0, heartbeat_interval=30, requeue_delay=90,
tls_v1=False, tls_options=None, snappy=False, user_agent=None,
output_buffer_size=16 * 1024, output_buffer_timeout=250, sample_rate=0,
io_loop=None,
authentication_secret=None):
auth_secret=None):
assert isinstance(host, (str, unicode))
assert isinstance(port, int)
assert isinstance(timeout, float)
Expand All @@ -103,7 +102,7 @@ def __init__(self, host, port, timeout=1.0, heartbeat_interval=30, requeue_delay
assert isinstance(output_buffer_size, int) and output_buffer_size >= 0
assert isinstance(output_buffer_timeout, int) and output_buffer_timeout >= 0
assert isinstance(sample_rate, int) and sample_rate >= 0 and sample_rate < 100
assert isinstance(authentication_secret, (str, unicode, None.__class__))
assert isinstance(auth_secret, (str, unicode, None.__class__))
assert tls_v1 and ssl or not tls_v1, \
'tls_v1 requires Python 2.6+ or Python 2.5 w/ pip install ssl'

Expand Down Expand Up @@ -139,7 +138,7 @@ def __init__(self, host, port, timeout=1.0, heartbeat_interval=30, requeue_delay
self.user_agent = 'pynsq/%s' % __version__

self._authentication_required = False # tracking server auth state
self.authentication_secret = authentication_secret
self.auth_secret = auth_secret
super(AsyncConn, self).__init__()

@property
Expand Down Expand Up @@ -315,11 +314,11 @@ def _on_response_continue(self, data, **kwargs):
return

self.off('response', self._on_response_continue)
if self.authentication_secret and self._authentication_required:
if self.auth_secret and self._authentication_required:
self.on('response', self._on_auth_response)
self.trigger('auth', conn=self, data=self.authentication_secret)
self.trigger('auth', conn=self, data=self.auth_secret)
try:
self.send(nsq.auth(self.authentication_secret))
self.send(nsq.auth(self.auth_secret))
except Exception, e:
self.close()
self.trigger('error', conn=self, error=nsq.SendError('Error sending AUTH', e))
Expand Down
37 changes: 36 additions & 1 deletion nsq/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import functools
import urllib
import random
import urlparse
import cgi

try:
import simplejson as json
Expand Down Expand Up @@ -502,7 +504,17 @@ def query_lookupd(self):
"""
endpoint = self.lookupd_http_addresses[self.lookupd_query_index]
self.lookupd_query_index = (self.lookupd_query_index + 1) % len(self.lookupd_http_addresses)
lookupd_url = endpoint + '/lookup?topic=' + urllib.quote(self.topic)

scheme, netloc, path, query, fragment = urlparse.urlsplit(endpoint)

if not path or path == "/":
path = "/lookup"

params = cgi.parse_qs(query)
params['topic'] = self.topic
query = urllib.urlencode(_utf8_params(params), doseq=1)
lookupd_url = urlparse.urlunsplit((scheme, netloc, path, query, fragment))

req = tornado.httpclient.HTTPRequest(lookupd_url, method='GET',
connect_timeout=1, request_timeout=2)
callback = functools.partial(self._finish_query_lookupd, lookupd_url=lookupd_url)
Expand Down Expand Up @@ -633,3 +645,26 @@ def validate_message(self, message):

def preprocess_message(self, message):
return message

def _utf8_params(params):
"""encode a dictionary of URL parameters (including iterables) as utf-8"""
assert isinstance(params, dict)
encoded_params = []
for k, v in params.items():
if v is None:
continue
if isinstance(v, (int, long, float)):
v = str(v)
if isinstance(v, (list, tuple)):
v = [_utf8(x) for x in v]
else:
v = _utf8(v)
encoded_params.append((k, v))
return dict(encoded_params)

def _utf8(s):
"""encode a unicode string as utf-8"""
if isinstance(s, unicode):
return s.encode("utf-8")
assert isinstance(s, str), "_utf8 expected a str, not %r" % type(s)
return s