Skip to content

Commit

Permalink
PYTHON-1369 Extend driver vector support to arbitrary subtypes and fi…
Browse files Browse the repository at this point in the history
…x handling of variable length types (OSS C* 5.0) (#1217)
  • Loading branch information
absurdfarce authored Sep 4, 2024
1 parent d05e9d3 commit c4a808d
Show file tree
Hide file tree
Showing 7 changed files with 504 additions and 72 deletions.
6 changes: 0 additions & 6 deletions cassandra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,9 +744,3 @@ def __init__(self, msg, excs=[]):
if excs:
complete_msg += ("\nThe following exceptions were observed: \n - " + '\n - '.join(str(e) for e in excs))
Exception.__init__(self, complete_msg)

class VectorDeserializationFailure(DriverException):
"""
The driver was unable to deserialize a given vector
"""
pass
97 changes: 79 additions & 18 deletions cassandra/cqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
int32_pack, int32_unpack, int64_pack, int64_unpack,
float_pack, float_unpack, double_pack, double_unpack,
varint_pack, varint_unpack, point_be, point_le,
vints_pack, vints_unpack)
from cassandra import util, VectorDeserializationFailure
vints_pack, vints_unpack, uvint_unpack, uvint_pack)
from cassandra import util

_little_endian_flag = 1 # we always serialize LE
import ipaddress
Expand Down Expand Up @@ -392,6 +392,9 @@ def cass_parameterized_type(cls, full=False):
"""
return cls.cass_parameterized_type_with(cls.subtypes, full=full)

@classmethod
def serial_size(cls):
return None

# it's initially named with a _ to avoid registering it as a real type, but
# client programs may want to use the name still for isinstance(), etc
Expand Down Expand Up @@ -457,10 +460,12 @@ def serialize(uuid, protocol_version):
except AttributeError:
raise TypeError("Got a non-UUID object for a UUID value")

@classmethod
def serial_size(cls):
return 16

class BooleanType(_CassandraType):
typename = 'boolean'
serial_size = 1

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -470,6 +475,10 @@ def deserialize(byts, protocol_version):
def serialize(truth, protocol_version):
return int8_pack(truth)

@classmethod
def serial_size(cls):
return 1

class ByteType(_CassandraType):
typename = 'tinyint'

Expand Down Expand Up @@ -500,7 +509,6 @@ def serialize(var, protocol_version):

class FloatType(_CassandraType):
typename = 'float'
serial_size = 4

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -510,10 +518,12 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return float_pack(byts)

@classmethod
def serial_size(cls):
return 4

class DoubleType(_CassandraType):
typename = 'double'
serial_size = 8

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -523,10 +533,12 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return double_pack(byts)

@classmethod
def serial_size(cls):
return 8

class LongType(_CassandraType):
typename = 'bigint'
serial_size = 8

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -536,10 +548,12 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return int64_pack(byts)

@classmethod
def serial_size(cls):
return 8

class Int32Type(_CassandraType):
typename = 'int'
serial_size = 4

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -549,6 +563,9 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return int32_pack(byts)

@classmethod
def serial_size(cls):
return 4

class IntegerType(_CassandraType):
typename = 'varint'
Expand Down Expand Up @@ -645,14 +662,16 @@ def serialize(v, protocol_version):

return int64_pack(int(timestamp))

@classmethod
def serial_size(cls):
return 8

class TimestampType(DateType):
pass


class TimeUUIDType(DateType):
typename = 'timeuuid'
serial_size = 16

def my_timestamp(self):
return util.unix_time_from_uuid1(self.val)
Expand All @@ -668,6 +687,9 @@ def serialize(timeuuid, protocol_version):
except AttributeError:
raise TypeError("Got a non-UUID object for a UUID value")

@classmethod
def serial_size(cls):
return 16

class SimpleDateType(_CassandraType):
typename = 'date'
Expand Down Expand Up @@ -699,7 +721,6 @@ def serialize(val, protocol_version):

class ShortType(_CassandraType):
typename = 'smallint'
serial_size = 2

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -709,10 +730,14 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return int16_pack(byts)


class TimeType(_CassandraType):
typename = 'time'
serial_size = 8
# Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as
# variable size... and we have to match what the server expects since the server
# uses that specification to encode data of that type.
#@classmethod
#def serial_size(cls):
# return 8

@staticmethod
def deserialize(byts, protocol_version):
Expand Down Expand Up @@ -1409,6 +1434,11 @@ class VectorType(_CassandraType):
vector_size = 0
subtype = None

@classmethod
def serial_size(cls):
    """
    Return the total serialized size of this vector type, or ``None`` when
    the subtype does not report a fixed serialized size.

    The total is ``vector_size`` multiplied by the subtype's own
    ``serial_size()``.
    """
    element_size = cls.subtype.serial_size()
    if element_size is None:
        # Subtype reports no fixed size, so the vector has none either.
        return None
    return cls.vector_size * element_size

@classmethod
def apply_parameters(cls, params, names):
assert len(params) == 2
Expand All @@ -1418,19 +1448,50 @@ def apply_parameters(cls, params, names):

@classmethod
def deserialize(cls, byts, protocol_version):
    """
    Decode a serialized vector into a list of ``cls.vector_size`` elements.

    When the subtype reports a fixed serialized size, the payload must be
    exactly ``serial_size() * vector_size`` bytes and is split into equal
    slices. Otherwise every element is read as an unsigned vint length
    prefix followed by that many bytes of element data.

    :param byts: the serialized vector payload
    :param protocol_version: passed through to the subtype's ``deserialize``
    :return: list of deserialized elements
    :raises ValueError: if the payload length does not match the fixed-size
        expectation, if an element cannot be read (including a length prefix
        that points past the end of the payload), or if bytes remain after
        the last element
    """
    serialized_size = cls.subtype.serial_size()
    if serialized_size is not None:
        expected_byte_size = serialized_size * cls.vector_size
        if len(byts) != expected_byte_size:
            raise ValueError(
                "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"
                .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts)))
        indexes = (serialized_size * x for x in range(0, cls.vector_size))
        return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]

    idx = 0
    rv = []
    total_size = len(byts)
    while len(rv) < cls.vector_size:
        try:
            size, bytes_read = uvint_unpack(byts[idx:])
            idx += bytes_read
            end = idx + size
            # Slicing past the end silently truncates, so check explicitly
            # that the announced element length fits in the payload.
            if end > total_size:
                raise ValueError(
                    "Vector element of {0} bytes at offset {1} extends past the end of the {2}-byte payload"
                    .format(size, idx, total_size))
            rv.append(cls.subtype.deserialize(byts[idx:end], protocol_version))
            idx = end
        except Exception as exc:
            # Only catch Exception so KeyboardInterrupt/SystemExit propagate,
            # and chain the cause so the underlying failure stays visible.
            raise ValueError(
                "Error reading additional data during vector deserialization after successfully adding {} elements"
                .format(len(rv))) from exc

    # If we have any additional data in the serialized vector treat that as an error as well
    if idx < total_size:
        raise ValueError("Additional bytes remaining after vector deserialization completed")
    return rv

@classmethod
def serialize(cls, v, protocol_version):
    """
    Encode the sequence ``v`` as a serialized vector.

    Each element is serialized with the subtype. When the subtype reports no
    fixed serialized size, every element is preceded by its byte length
    encoded as an unsigned vint; fixed-size elements are written back to back
    with no prefix.

    :param v: sequence whose length must equal ``cls.vector_size``
    :param protocol_version: passed through to the subtype's ``serialize``
    :return: the serialized vector as ``bytes``
    :raises ValueError: if ``len(v)`` differs from ``cls.vector_size``
    """
    v_length = len(v)
    if cls.vector_size != v_length:
        raise ValueError(
            "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"
            .format(cls.vector_size, cls.subtype.typename, v_length))

    serialized_size = cls.subtype.serial_size()
    buf = io.BytesIO()
    for item in v:
        # Write each element exactly once; a second, unprefixed write would
        # corrupt the payload.
        item_bytes = cls.subtype.serialize(item, protocol_version)
        if serialized_size is None:
            buf.write(uvint_pack(len(item_bytes)))
        buf.write(item_bytes)
    return buf.getvalue()

@classmethod
def cql_parameterized_type(cls):
    """
    Return the CQL type string for this vector, in the form
    ``<typename><<subtype>, <vector_size>>``.

    The subtype is rendered through its own ``cql_parameterized_type()`` so
    that a parameterized subtype keeps its parameters in the output.
    """
    return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size)
5 changes: 5 additions & 0 deletions cassandra/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
log = logging.getLogger(__name__)

from binascii import hexlify
from decimal import Decimal
import calendar
import datetime
import math
Expand Down Expand Up @@ -59,6 +60,7 @@ class Encoder(object):
def __init__(self):
self.mapping = {
float: self.cql_encode_float,
Decimal: self.cql_encode_decimal,
bytearray: self.cql_encode_bytes,
str: self.cql_encode_str,
int: self.cql_encode_object,
Expand Down Expand Up @@ -217,3 +219,6 @@ def cql_encode_ipaddress(self, val):
is suitable for ``inet`` type columns.
"""
return "'%s'" % val.compressed

def cql_encode_decimal(self, val):
    """
    Converts ``val`` to a ``float`` and encodes the result with
    :meth:`cql_encode_float`.
    """
    as_float = float(val)
    return self.cql_encode_float(as_float)
46 changes: 45 additions & 1 deletion cassandra/marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def vints_unpack(term): # noqa

return tuple(values)


def vints_pack(values):
revbytes = bytearray()
values = [int(v) for v in values[::-1]]
Expand Down Expand Up @@ -143,3 +142,48 @@ def vints_pack(values):

revbytes.reverse()
return bytes(revbytes)

def uvint_unpack(bytes):
    """
    Decode one unsigned variable-length integer from the start of ``bytes``.

    The count of leading 1 bits in the first byte is the number of extra
    bytes that follow. The remaining low bits of the first byte are the most
    significant bits of the value; the extra bytes follow in big-endian order.

    Returns a ``(value, bytes_consumed)`` tuple. Data after the encoded
    integer is ignored. Indexing past the end of ``bytes`` raises
    ``IndexError``.
    """
    lead = bytes[0]

    # High bit clear: the whole value lives in the first byte.
    if not lead & 0x80:
        return (lead, 1)

    # The highest 0 bit of the lead byte marks the end of the run of 1 bits.
    extra_count = 8 - (~lead & 0xff).bit_length()
    consumed = extra_count + 1

    value = lead & (0xff >> extra_count)
    for position in range(1, consumed):
        value = (value << 8) | (bytes[position] & 0xff)

    return (value, consumed)

def uvint_pack(val):
    """
    Encode the non-negative integer ``val`` as an unsigned variable-length
    integer and return the encoded ``bytes``.

    Values below 128 occupy a single byte. Larger values use a lead byte
    whose run of leading 1 bits gives the number of extra bytes, followed by
    those extra bytes in big-endian order. With 8 extra bytes the lead byte
    is ``0xff`` and carries no value bits.

    Raises ``ValueError`` when more than 8 extra bytes would be needed.
    """
    if val < 128:
        single = bytearray()
        single.append(val)
        return bytes(single)

    remaining = val
    bits_left = remaining.bit_length()
    extra_count = 0
    encoded = bytearray()  # built least-significant byte first, reversed at the end

    # The lead byte reserves (extra_count + 1) bits for the length marker,
    # capped at 8: e.g. '10XXXXXX' for one extra byte, '11111111' for eight.
    while bits_left > 8 - min(extra_count + 1, 8):
        encoded.append(remaining & 0xff)
        remaining >>= 8
        bits_left -= 8
        extra_count += 1

    if extra_count > 8:
        raise ValueError('Value %d is too big and cannot be encoded as vint' % val)

    # Whatever is left fits beside the length marker in the lead byte.
    free_bits = 8 - extra_count
    marker = 0xff >> free_bits << free_bits
    encoded.append(remaining | marker)

    encoded.reverse()
    return bytes(encoded)
7 changes: 4 additions & 3 deletions tests/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,10 @@ def _id_and_mark(f):
greaterthanorequalcass36 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.6'), 'Cassandra version 3.6 or greater required')
greaterthanorequalcass3_10 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.10'), 'Cassandra version 3.10 or greater required')
greaterthanorequalcass3_11 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.11'), 'Cassandra version 3.11 or greater required')
greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0-a'), 'Cassandra version 4.0 or greater required')
lessthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION <= Version('4.0-a'), 'Cassandra version less or equal to 4.0 required')
lessthancass40 = unittest.skipUnless(CASSANDRA_VERSION < Version('4.0-a'), 'Cassandra version less than 4.0 required')
greaterthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION >= Version('4.0'), 'Cassandra version 4.0 or greater required')
greaterthanorequalcass50 = unittest.skipUnless(CASSANDRA_VERSION >= Version('5.0-beta'), 'Cassandra version 5.0 or greater required')
lessthanorequalcass40 = unittest.skipUnless(CASSANDRA_VERSION <= Version('4.0'), 'Cassandra version less or equal to 4.0 required')
lessthancass40 = unittest.skipUnless(CASSANDRA_VERSION < Version('4.0'), 'Cassandra version less than 4.0 required')
lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < Version('3.0'), 'Cassandra version less than 3.0 required')
greaterthanorequaldse68 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.8'), "DSE 6.8 or greater required for this test")
greaterthanorequaldse67 = unittest.skipUnless(DSE_VERSION and DSE_VERSION >= Version('6.7'), "DSE 6.7 or greater required for this test")
Expand Down
Loading

0 comments on commit c4a808d

Please sign in to comment.