Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-1369 Extend driver vector support to arbitrary subtypes and fix handling of variable length types (OSS C* 5.0) #1217

Merged
merged 25 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d153c61
Initial commit of unit test
absurdfarce Jul 18, 2024
69f54b0
What appears to be a working test now
absurdfarce Jul 18, 2024
72a27ac
Some test refinements
absurdfarce Jul 18, 2024
fe7a3b5
Seems to be working now
absurdfarce Jul 19, 2024
c295a8c
Added tests for string, map and vector subtype cases
absurdfarce Jul 19, 2024
d0d5983
We have most things working now. Vector of vectors still seems to be…
absurdfarce Jul 22, 2024
3534876
Fix VectorType.cql_parameterized_type() to properly handle vectors of…
absurdfarce Jul 23, 2024
a0791b0
Fix error in test
absurdfarce Jul 23, 2024
f45d7df
Test fixes
absurdfarce Jul 23, 2024
daa54f1
Removing test client. This will eventually come back (in the form of…
absurdfarce Jul 24, 2024
07c86bb
Remove custom exception type added in PYTHON-1371
absurdfarce Jul 24, 2024
3363d16
Allowing user to pass in custom libev includes and libs via env vars
absurdfarce Aug 21, 2024
1cba1b9
Revert "Allowing user to pass in custom libev includes and libs via e…
absurdfarce Aug 21, 2024
44966c9
Merge branch 'master' into python1369
absurdfarce Aug 23, 2024
dcb008f
Initial sketch of what the bones of an integration test might look like
absurdfarce Aug 27, 2024
8442743
Just moving some things around
absurdfarce Aug 28, 2024
f35dcda
Short is (incorrectly) marked as a variable size type on the server side
absurdfarce Aug 28, 2024
4f6bef8
Add support for Decimal types as positional params
absurdfarce Aug 28, 2024
e23e425
Passing test with basic integer and floating point types
absurdfarce Aug 28, 2024
e101213
A few more fixed size types we missed
absurdfarce Aug 28, 2024
1840a6d
Passing test which covers everything except UDTs
absurdfarce Aug 28, 2024
1f7dd90
Testing for UDTs now included
absurdfarce Aug 28, 2024
c053b67
Explicitly throw ValueErrors when deserializing vectors with too litt…
absurdfarce Aug 30, 2024
16cef42
Some minor cleanup of test fn names
absurdfarce Aug 30, 2024
f996a7d
Added checks for vector serialize op (and tests)
absurdfarce Sep 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions cassandra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,9 +744,3 @@ def __init__(self, msg, excs=[]):
if excs:
complete_msg += ("\nThe following exceptions were observed: \n - " + '\n - '.join(str(e) for e in excs))
Exception.__init__(self, complete_msg)

class VectorDeserializationFailure(DriverException):
"""
The driver was unable to deserialize a given vector
"""
pass
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in PYTHON-1371 to indicate an attempt to decode a vector of an unsupported subtype. With the other changes in this PR this exception is now completely unnecessary.

74 changes: 57 additions & 17 deletions cassandra/cqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
int32_pack, int32_unpack, int64_pack, int64_unpack,
float_pack, float_unpack, double_pack, double_unpack,
varint_pack, varint_unpack, point_be, point_le,
vints_pack, vints_unpack)
from cassandra import util, VectorDeserializationFailure
vints_pack, vints_unpack, uvint_unpack, uvint_pack)
from cassandra import util

_little_endian_flag = 1 # we always serialize LE
import ipaddress
Expand Down Expand Up @@ -393,6 +393,9 @@ def cass_parameterized_type(cls, full=False):
"""
return cls.cass_parameterized_type_with(cls.subtypes, full=full)

@classmethod
def serial_size(cls):
return None

# it's initially named with a _ to avoid registering it as a real type, but
# client programs may want to use the name still for isinstance(), etc
Expand Down Expand Up @@ -461,7 +464,6 @@ def serialize(uuid, protocol_version):

class BooleanType(_CassandraType):
typename = 'boolean'
serial_size = 1

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -471,6 +473,10 @@ def deserialize(byts, protocol_version):
def serialize(truth, protocol_version):
return int8_pack(truth)

@classmethod
def serial_size(cls):
return 1

class ByteType(_CassandraType):
typename = 'tinyint'

Expand Down Expand Up @@ -501,7 +507,6 @@ def serialize(var, protocol_version):

class FloatType(_CassandraType):
typename = 'float'
serial_size = 4

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -511,10 +516,12 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return float_pack(byts)

@classmethod
def serial_size(cls):
return 4

class DoubleType(_CassandraType):
typename = 'double'
serial_size = 8

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -524,10 +531,12 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return double_pack(byts)

@classmethod
def serial_size(cls):
return 8

class LongType(_CassandraType):
typename = 'bigint'
serial_size = 8

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -537,10 +546,12 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return int64_pack(byts)

@classmethod
def serial_size(cls):
return 8

class Int32Type(_CassandraType):
typename = 'int'
serial_size = 4

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -550,6 +561,9 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return int32_pack(byts)

@classmethod
def serial_size(cls):
return 4

class IntegerType(_CassandraType):
typename = 'varint'
Expand Down Expand Up @@ -653,7 +667,6 @@ class TimestampType(DateType):

class TimeUUIDType(DateType):
typename = 'timeuuid'
serial_size = 16

def my_timestamp(self):
return util.unix_time_from_uuid1(self.val)
Expand All @@ -669,6 +682,9 @@ def serialize(timeuuid, protocol_version):
except AttributeError:
raise TypeError("Got a non-UUID object for a UUID value")

@classmethod
def serial_size(cls):
return 16

class SimpleDateType(_CassandraType):
typename = 'date'
Expand Down Expand Up @@ -700,7 +716,6 @@ def serialize(val, protocol_version):

class ShortType(_CassandraType):
typename = 'smallint'
serial_size = 2

@staticmethod
def deserialize(byts, protocol_version):
Expand All @@ -710,10 +725,18 @@ def deserialize(byts, protocol_version):
def serialize(byts, protocol_version):
return int16_pack(byts)

@classmethod
def serial_size(cls):
return 2

class TimeType(_CassandraType):
typename = 'time'
serial_size = 8
# Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as
# variable size... and we have to match what the server expects since the server
# uses that specification to encode data of that type.
#@classmethod
#def serial_size(cls):
# return 8

@staticmethod
def deserialize(byts, protocol_version):
Expand Down Expand Up @@ -1410,6 +1433,11 @@ class VectorType(_CassandraType):
vector_size = 0
subtype = None

@classmethod
def serial_size(cls):
serialized_size = cls.subtype.serial_size()
return cls.vector_size * serialized_size if serialized_size is not None else None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here (as in many other points in this PR) we're following the example of what's done on the server side.


@classmethod
def apply_parameters(cls, params, names):
assert len(params) == 2
Expand All @@ -1419,19 +1447,31 @@ def apply_parameters(cls, params, names):

@classmethod
def deserialize(cls, byts, protocol_version):
serialized_size = getattr(cls.subtype, "serial_size", None)
if not serialized_size:
raise VectorDeserializationFailure("Cannot determine serialized size for vector with subtype %s" % cls.subtype.__name__)
indexes = (serialized_size * x for x in range(0, cls.vector_size))
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]
serialized_size = cls.subtype.serial_size()
if serialized_size is not None:
indexes = (serialized_size * x for x in range(0, cls.vector_size))
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]

idx = 0
rv = []
while (idx < len(byts)):
size, bytes_read = uvint_unpack(byts[idx:])
idx += bytes_read
rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version))
idx += size
return rv
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to throw an error when the number of elements in byts does not match the vector's declared dimension? Same for serialize.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually happening (at least for deserialize) with the changes above. I've also recently added support for similar tests in serialize() (along with tests for all the cases).


@classmethod
def serialize(cls, v, protocol_version):
buf = io.BytesIO()
serialized_size = cls.subtype.serial_size()
for item in v:
buf.write(cls.subtype.serialize(item, protocol_version))
item_bytes = cls.subtype.serialize(item, protocol_version)
if serialized_size is None:
buf.write(uvint_pack(len(item_bytes)))
buf.write(item_bytes)
return buf.getvalue()

@classmethod
def cql_parameterized_type(cls):
return "%s<%s, %s>" % (cls.typename, cls.subtype.typename, cls.vector_size)
return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size)
46 changes: 45 additions & 1 deletion cassandra/marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def vints_unpack(term): # noqa

return tuple(values)


def vints_pack(values):
revbytes = bytearray()
values = [int(v) for v in values[::-1]]
Expand Down Expand Up @@ -143,3 +142,48 @@ def vints_pack(values):

revbytes.reverse()
return bytes(revbytes)

def uvint_unpack(bytes):
    """
    Decode a single unsigned vint found at the start of ``bytes``.

    If the high bit of the first byte is clear, that byte is the value.
    Otherwise the number of leading 1 bits in the first byte gives the
    number of additional bytes to read; the remaining low bits of the first
    byte are the most significant bits of the value and the additional bytes
    follow in big-endian order.

    Returns a ``(value, bytes_consumed)`` tuple.  Indexing past the end of
    the input (empty or truncated data) propagates the underlying
    ``IndexError``.
    """
    lead = bytes[0]

    # Single-byte encoding: high bit clear, value is the byte itself
    if not lead & 0x80:
        return (lead, 1)

    # Inverting the lead byte turns its run of leading 1s into leading 0s,
    # so the bit length of the result tells us how long that run was
    extra = 8 - (~lead & 0xff).bit_length()
    value = lead & (0xff >> extra)
    for offset in range(extra):
        value = (value << 8) | (bytes[offset + 1] & 0xff)

    return (value, extra + 1)

def uvint_pack(val):
    """
    Encode the non-negative integer ``val`` as an unsigned vint.

    Values below 128 are emitted as a single byte.  Larger values are
    emitted as a lead byte followed by up to 8 additional bytes in
    big-endian order; the lead byte starts with one 1 bit per additional
    byte and carries the most significant bits of the value in whatever
    bits are left over.

    Returns the encoded ``bytes``.  Raises ``ValueError`` if the value needs
    more than 8 additional bytes.
    """
    encoded = bytearray()
    if val < 128:
        encoded.append(val)
        return bytes(encoded)

    remaining = val
    bits_left = remaining.bit_length()
    extra_bytes = 0
    # The lead byte must reserve (extra_bytes + 1) bits for the length prefix:
    #   1 extra byte  -> lead byte looks like '10XXXXXX' (2 bits reserved)
    #   8 extra bytes -> lead byte is '11111111'         (8 bits reserved)
    reserved = 1
    while bits_left > 8 - reserved:
        # Peel off the low-order byte; the buffer is built least significant
        # byte first and flipped at the end
        encoded.append(remaining & 0xff)
        remaining >>= 8
        bits_left -= 8
        extra_bytes += 1
        reserved = min(extra_bytes + 1, 8)

    if extra_bytes > 8:
        raise ValueError('Value %d is too big and cannot be encoded as vint' % val)

    # Whatever is left fits in the lead byte alongside the length prefix
    payload_bits = 8 - extra_bytes
    remaining |= (0xff >> payload_bits << payload_bits)
    encoded.append(abs(remaining))

    encoded.reverse()
    return bytes(encoded)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adapted from vints_pack and vints_unpack above. Those functions are different enough (built-in zig-zag encoding + tuples for incoming/outgoing data) that it seemed worthwhile to live with the code duplication here.

Loading