Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

READY: Socket enhancements #453

Merged
merged 11 commits into from
Apr 30, 2016
4 changes: 2 additions & 2 deletions riak/client/transport.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import contextmanager
from riak.transports.pool import BadResource
from riak.transports.tcp import is_retryable as is_pbc_retryable
from riak.transports.tcp import is_retryable as is_tcp_retryable
from riak.transports.http import is_retryable as is_http_retryable
import threading
from six import PY2
Expand Down Expand Up @@ -162,7 +162,7 @@ def _is_retryable(error):
:type error: Exception
:rtype: boolean
"""
return is_pbc_retryable(error) or is_http_retryable(error)
return is_tcp_retryable(error) or is_http_retryable(error)


def retryable(fn, protocol=None):
Expand Down
10 changes: 7 additions & 3 deletions riak/riak_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ class RiakError(Exception):
"""
Base class for exceptions generated in the Riak API.
"""
def __init__(self, value):
self.value = value
def __init__(self, *args, **kwargs):
super(RiakError, self).__init__(*args, **kwargs)
if len(args) > 0:
self.value = args[0]
else:
self.value = 'unknown'

def __str__(self):
return repr(self.value)
Expand All @@ -34,5 +38,5 @@ class ConflictError(RiakError):
:class:`~riak.riak_object.RiakObject` that has more than one
sibling.
"""
def __init__(self, message="Object in conflict"):
def __init__(self, message='Object in conflict'):
super(ConflictError, self).__init__(message)
19 changes: 12 additions & 7 deletions riak/tests/test_btypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,18 @@ def test_multiget_bucket_types(self):
self.assertEqual(btype, mobj.bucket.bucket_type)

def test_write_once_bucket_type(self):
btype = self.client.bucket_type('write_once')
bucket = btype.bucket(self.bucket_name)

for i in range(100):
obj = bucket.new(self.key_name + str(i))
obj.data = {'id': i}
obj.store()
bt = 'write_once'
skey = 'write_once-init'
btype = self.client.bucket_type(bt)
bucket = btype.bucket(bt)
sobj = bucket.get(skey)
if not sobj.exists:
for i in range(100):
o = bucket.new(self.key_name + str(i))
o.data = {'id': i}
o.store()
o = bucket.new(skey, data={'id': skey})
o.store()

mget = bucket.multiget([self.key_name + str(i) for i in range(100)])
for mobj in mget:
Expand Down
12 changes: 12 additions & 0 deletions riak/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from six import PY2
from threading import Thread
from riak.riak_object import RiakObject
from riak.transports.tcp import TcpTransport
from riak.tests import DUMMY_HTTP_PORT, DUMMY_PB_PORT, RUN_POOL
from riak.tests.base import IntegrationTestBase

Expand All @@ -13,6 +14,17 @@


class ClientTests(IntegrationTestBase, unittest.TestCase):
def test_can_set_tcp_keepalive(self):
if self.protocol == 'pbc':
topts = {'socket_keepalive': True}
c = self.create_client(transport_options=topts)
for i, r in enumerate(c._tcp_pool.resources):
self.assertIsInstance(r, TcpTransport)
self.assertTrue(r._socket_keepalive)
c.close()
else:
pass

def test_uses_client_id_if_given(self):
if self.protocol == 'pbc':
zero_client_id = "\0\0\0\0"
Expand Down
25 changes: 18 additions & 7 deletions riak/tests/test_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,29 @@ def test_string_bucket_name(self):
def test_generate_key(self):
# Ensure that Riak generates a random key when
# the key passed to bucket.new() is None.
bucket = self.client.bucket('random_key_bucket')
existing_keys = bucket.get_keys()
bucket = self.client.bucket(self.bucket_name)
o = bucket.new(None, data={})
self.assertIsNone(o.key)
o.store()
self.assertIsNotNone(o.key)
self.assertNotIn('/', o.key)
self.assertNotIn(o.key, existing_keys)
self.assertEqual(len(bucket.get_keys()), len(existing_keys) + 1)
existing_keys = bucket.get_keys()
self.assertEqual(len(existing_keys), 1)

def maybe_store_keys(self):
skey = 'rkb-init'
bucket = self.client.bucket('random_key_bucket')
sobj = bucket.get(skey)
if sobj.exists:
return
for key in range(1, 1000):
o = bucket.new(None, data={})
o.store()
o = bucket.new(skey, data={})
o.store()

def test_stream_keys(self):
self.maybe_store_keys()
bucket = self.client.bucket('random_key_bucket')
regular_keys = bucket.get_keys()
self.assertNotEqual(len(regular_keys), 0)
Expand All @@ -203,10 +215,8 @@ def test_stream_keys(self):
self.assertEqual(sorted(regular_keys), sorted(streamed_keys))

def test_stream_keys_timeout(self):
self.maybe_store_keys()
bucket = self.client.bucket('random_key_bucket')
for key in range(1, 1000):
o = bucket.new(None, data={})
o.store()
streamed_keys = []
with self.assertRaises(RiakError):
for keylist in self.client.stream_keys(bucket, timeout=1):
Expand All @@ -216,6 +226,7 @@ def test_stream_keys_timeout(self):
streamed_keys += keylist

def test_stream_keys_abort(self):
self.maybe_store_keys()
bucket = self.client.bucket('random_key_bucket')
regular_keys = bucket.get_keys()
self.assertNotEqual(len(regular_keys), 0)
Expand Down
2 changes: 1 addition & 1 deletion riak/transports/tcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def destroy_resource(self, tcp):
def is_retryable(err):
"""
Determines if the given exception is something that is
network/socket-related and should thus cause the PBC connection to
network/socket-related and should thus cause the TCP connection to
close and the operation retried on another node.

:rtype: boolean
Expand Down
41 changes: 33 additions & 8 deletions riak/transports/tcp/connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import errno
import socket
import struct

Expand All @@ -7,16 +8,23 @@
from riak import RiakError
from riak.codecs.pbuf import PbufCodec
from riak.security import SecurityError, USE_STDLIB_SSL
from riak.transports.pool import BadResource

if not USE_STDLIB_SSL:
from OpenSSL.SSL import Connection
from riak.transports.security import configure_pyopenssl_context
else:
if USE_STDLIB_SSL:
import ssl
from riak.transports.security import configure_ssl_context
else:
from OpenSSL.SSL import Connection
from riak.transports.security import configure_pyopenssl_context


class TcpConnection(object):
# These are set in the TcpTransport initializer
_address = None
_timeout = None
_socket_keepalive = None
_socket_tcp_options = None

"""
Connection-related methods for TcpTransport.
"""
Expand Down Expand Up @@ -174,6 +182,10 @@ def _recv(self, msglen):
toread = msglen
while toread:
nbytes = self._socket.recv_into(view, toread)
# https://docs.python.org/2/howto/sockets.html#using-a-socket
# https://github.com/basho/riak-python-client/issues/399
if nbytes == 0:
raise BadResource('recv_into returned zero bytes unexpectedly')
view = view[nbytes:] # slicing views is cheap
toread -= nbytes
nread += nbytes
Expand All @@ -189,6 +201,13 @@ def _connect(self):
self._timeout)
else:
self._socket = socket.create_connection(self._address)
if self._socket_tcp_options:
ka_opts = self._socket_tcp_options
for k, v in ka_opts.iteritems():
self._socket.setsockopt(socket.SOL_TCP, k, v)
if self._socket_keepalive:
self._socket.setsockopt(
socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if self._client._credentials:
self._init_security()

Expand All @@ -197,9 +216,15 @@ def close(self):
Closes the underlying socket of the PB connection.
"""
if self._socket:
if USE_STDLIB_SSL:
# NB: Python 2.7.8 and earlier does not have a compatible
# shutdown() method due to the SSL lib
try:
self._socket.shutdown(socket.SHUT_RDWR)
except IOError as e:
# NB: sometimes this is the exception if the initial
# connection didn't succeed correctly
if e.errno != errno.EBADF:
raise
self._socket.close()
del self._socket

# These are set in the TcpTransport initializer
_address = None
_timeout = None
7 changes: 6 additions & 1 deletion riak/transports/tcp/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ def __init__(self,
self._socket = None
self._pbuf_c = None
self._ttb_c = None
self._use_ttb = kwargs.get('use_ttb', True)
self._socket_tcp_options = \
kwargs.get('socket_tcp_options', {})
self._socket_keepalive = \
kwargs.get('socket_keepalive', False)
self._use_ttb = \
kwargs.get('use_ttb', True)

def _get_pbuf_codec(self):
if not self._pbuf_c:
Expand Down