Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-1369 Extend driver vector support to arbitrary subtypes and fix handling of variable length types (OSS C* 5.0) #1217

Merged
merged 25 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d153c61
Initial commit of unit test
absurdfarce Jul 18, 2024
69f54b0
What appears to be a working test now
absurdfarce Jul 18, 2024
72a27ac
Some test refinements
absurdfarce Jul 18, 2024
fe7a3b5
Seems to be working now
absurdfarce Jul 19, 2024
c295a8c
Added tests for string, map and vector subtype cases
absurdfarce Jul 19, 2024
d0d5983
We have most things working now. Vector of vectors still seems to be…
absurdfarce Jul 22, 2024
3534876
Fix VectorType.cql_parameterized_type() to properly handle vectors of…
absurdfarce Jul 23, 2024
a0791b0
Fix error in test
absurdfarce Jul 23, 2024
f45d7df
Test fixes
absurdfarce Jul 23, 2024
daa54f1
Removing test client. This will eventually come back (in the form of…
absurdfarce Jul 24, 2024
07c86bb
Remove custom exception type added in PYTHON-1371
absurdfarce Jul 24, 2024
3363d16
Allowing user to pass in custom libev includes and libs via env vars
absurdfarce Aug 21, 2024
1cba1b9
Revert "Allowing user to pass in custom libev includes and libs via e…
absurdfarce Aug 21, 2024
44966c9
Merge branch 'master' into python1369
absurdfarce Aug 23, 2024
dcb008f
Initial sketch of what the bones of an integration test might look like
absurdfarce Aug 27, 2024
8442743
Just moving some things around
absurdfarce Aug 28, 2024
f35dcda
Short is (incorrectly) marked as a variable size type on the server side
absurdfarce Aug 28, 2024
4f6bef8
Add support for Decimal types as positional params
absurdfarce Aug 28, 2024
e23e425
Passing test with basic integer and floating point types
absurdfarce Aug 28, 2024
e101213
A few more fixed size types we missed
absurdfarce Aug 28, 2024
1840a6d
Passing test which covers everything except UDTs
absurdfarce Aug 28, 2024
1f7dd90
Testing for UDTs now included
absurdfarce Aug 28, 2024
c053b67
Explicitly throw ValueErrors when deserializing vectors with too litt…
absurdfarce Aug 30, 2024
16cef42
Some minor cleanup of test fn names
absurdfarce Aug 30, 2024
f996a7d
Added checks for vector serialize op (and tests)
absurdfarce Sep 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions cassandra/cqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
int32_pack, int32_unpack, int64_pack, int64_unpack,
float_pack, float_unpack, double_pack, double_unpack,
varint_pack, varint_unpack, point_be, point_le,
vints_pack, vints_unpack)
vints_pack, vints_unpack, uvint_unpack, uvint_pack)
from cassandra import util, VectorDeserializationFailure

_little_endian_flag = 1 # we always serialize LE
Expand Down Expand Up @@ -713,7 +713,10 @@ def serialize(byts, protocol_version):

class TimeType(_CassandraType):
typename = 'time'
serial_size = 8
# Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as
# variable size... and we have to match what the server expects since the server
# uses that specification to encode data of that type.
#serial_size = 8

@staticmethod
def deserialize(byts, protocol_version):
Expand Down Expand Up @@ -1420,18 +1423,30 @@ def apply_parameters(cls, params, names):
@classmethod
def deserialize(cls, byts, protocol_version):
serialized_size = getattr(cls.subtype, "serial_size", None)
if not serialized_size:
raise VectorDeserializationFailure("Cannot determine serialized size for vector with subtype %s" % cls.subtype.__name__)
indexes = (serialized_size * x for x in range(0, cls.vector_size))
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]
if serialized_size is not None:
indexes = (serialized_size * x for x in range(0, cls.vector_size))
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]

idx = 0
rv = []
while (idx < len(byts)):
size, bytes_read = uvint_unpack(byts[idx:])
idx += bytes_read
rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version))
idx += size
return rv
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to throw an error when the length of elements in the byts does not match the vector dimension definition? Same for serialize.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually happening (at least for deserialize) with the changes above. I've also recently added support for similar checks in serialize() (along with tests for all the cases).


@classmethod
def serialize(cls, v, protocol_version):
buf = io.BytesIO()
serialized_size = getattr(cls.subtype, "serial_size", None)
for item in v:
buf.write(cls.subtype.serialize(item, protocol_version))
item_bytes = cls.subtype.serialize(item, protocol_version)
if serialized_size is None:
buf.write(uvint_pack(len(item_bytes)))
buf.write(item_bytes)
return buf.getvalue()

@classmethod
def cql_parameterized_type(cls):
return "%s<%s, %s>" % (cls.typename, cls.subtype.typename, cls.vector_size)
return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size)
46 changes: 45 additions & 1 deletion cassandra/marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def vints_unpack(term): # noqa

return tuple(values)


def vints_pack(values):
revbytes = bytearray()
values = [int(v) for v in values[::-1]]
Expand Down Expand Up @@ -143,3 +142,48 @@ def vints_pack(values):

revbytes.reverse()
return bytes(revbytes)

def uvint_unpack(bytes):
    """
    Decode one unsigned vint from the front of ``bytes``.

    The number of leading 1 bits in the first byte is the number of extra bytes
    that follow it. The remaining low bits of the first byte are the most
    significant bits of the value; the extra bytes follow in big-endian order.

    Returns a ``(value, bytes_consumed)`` tuple. Indexing past the end of
    ``bytes`` (empty or truncated input) raises IndexError.
    """
    lead = bytes[0]

    # High bit clear: the whole value lives in this single byte
    if not lead & 0x80:
        return (lead, 1)

    # Count the leading 1 bits by locating the highest 0 bit of the first byte
    extra = 8 - (~lead & 0xff).bit_length()
    value = lead & (0xff >> extra)
    for pos in range(1, extra + 1):
        value = (value << 8) | (bytes[pos] & 0xff)

    return (value, extra + 1)

def uvint_pack(val):
    """
    Encode a non-negative integer as an unsigned vint.

    Values below 128 are written as a single byte. Larger values are written as
    a first byte whose leading 1 bits count the extra bytes that follow, with
    the value's bits laid out big-endian across the first byte's remaining bits
    and the extra bytes. At most 8 extra bytes are supported, i.e. values up to
    2**64 - 1.

    Raises ValueError if ``val`` is negative or does not fit in 64 bits.
    """
    if val < 0:
        # Previously this only failed incidentally inside bytearray.append()
        # with an unrelated "byte must be in range" message.
        raise ValueError('Value %d is negative and cannot be encoded as vint' % val)
    if val < 128:
        return bytes(bytearray([val]))

    num_bits = val.bit_length()
    if num_bits > 64:
        # More than 8 extra bytes would be needed; reject before doing any work
        raise ValueError('Value %d is too big and cannot be encoded as vint' % val)

    rv = bytearray()
    v = val
    num_extra_bytes = 0
    # We need to reserve (num_extra_bytes+1) bits in the first byte
    # ie. with 1 extra byte, the first byte needs to be something like '10XXXXXX' # 2 bits reserved
    # ie. with 8 extra bytes, the first byte needs to be '11111111' # 8 bits reserved
    reserved_bits = 1
    while num_bits > (8 - reserved_bits):
        num_extra_bytes += 1
        num_bits -= 8
        reserved_bits = min(num_extra_bytes + 1, 8)
        # Bytes are collected least significant first and reversed at the end
        rv.append(v & 0xff)
        v >>= 8

    # Whatever is left of the value shares the first byte with the length prefix
    n = 8 - num_extra_bytes
    rv.append(v | (0xff >> n << n))

    rv.reverse()
    return bytes(rv)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adapted from vints_pack and vints_unpack above. Those functions are different enough (built-in zig-zag encoding + tuples for incoming/outgoing data) that it seemed worthwhile to live with the code duplication here.

123 changes: 123 additions & 0 deletions python1369_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import logging
import unittest

from cassandra.cluster import Cluster, Session

class Python1369Test(unittest.TestCase):
    """
    Integration tests for vector columns with a range of subtypes.

    Every test connects to a Cassandra node at 127.0.0.1. setUp() drops and
    recreates the "test" keyspace, so each test starts without a test.foo table.
    """

    def setUp(self):
        self.cluster = Cluster(['127.0.0.1'])
        # Register the shutdown immediately so connections are released even if
        # the rest of setUp (or the test itself) fails.
        self.addCleanup(self.cluster.shutdown)
        self.session = self.cluster.connect()
        self.session.execute("drop keyspace if exists test")
        ks_stmt = """CREATE KEYSPACE test
            WITH REPLICATION = {
                'class' : 'SimpleStrategy',
                'replication_factor' : 1
            }"""
        self.session.execute(ks_stmt)

    def _create_table(self, subtype):
        """Create test.foo with an int key and a 3-element vector of the given CQL subtype."""
        table_stmt = """CREATE TABLE test.foo (
            i int PRIMARY KEY,
            j vector<%s, 3>
        )""" % (subtype,)
        self.session.execute(table_stmt)

    def _populate_table(self, data):
        """Insert rows using CQL text built from the Python repr of each value."""
        for k, v in data.items():
            self.session.execute("insert into test.foo (i,j) values (%d,%s)" % (k, v))

    def _populate_table_prepared(self, data):
        """Insert rows through a prepared statement."""
        ps = self.session.prepare("insert into test.foo (i,j) values (?,?)")
        for k, v in data.items():
            self.session.execute(ps, [k, v])

    def _populate_table_params(self, data):
        """Insert rows through a simple statement with positional parameters."""
        for k, v in data.items():
            self.session.execute("insert into test.foo (i,j) values (%s,%s)", (k, v))

    def _create_and_populate_table(self, subtype="float", data=None):
        self._create_table(subtype)
        self._populate_table(data or {})

    def _create_and_populate_table_prepared(self, subtype="float", data=None):
        self._create_table(subtype)
        self._populate_table_prepared(data or {})

    def _assert_floats_close(self, observed, expected):
        self.assertAlmostEqual(observed, expected, places=5)

    def _execute_test(self, expected, test_fn):
        """Read back the row with key 2 and compare its vector to expected element by element."""
        rs = self.session.execute("select j from test.foo where i = 2")
        rows = rs.all()
        self.assertEqual(len(rows), 1)
        observed = rows[0].j
        self.assertEqual(len(observed), len(expected))
        for observed_elem, expected_elem in zip(observed, expected):
            test_fn(observed_elem, expected_elem)

    def test_float_vector(self):
        self.session.execute("drop table if exists test.foo")
        expected = [1.2, 3.4, 5.6]
        data = {1: [8, 2.3, 58], 2: expected, 5: [23, 18, 3.9]}
        self._create_and_populate_table(subtype="float", data=data)
        self._execute_test(expected, self._assert_floats_close)

    def test_float_vector_prepared(self):
        self.session.execute("drop table if exists test.foo")
        expected = [1.2, 3.4, 5.6]
        data = {1: [8, 2.3, 58], 2: expected, 5: [23, 18, 3.9]}
        self._create_and_populate_table_prepared(subtype="float", data=data)
        self._execute_test(expected, self._assert_floats_close)

    def test_varint_vector(self):
        self.session.execute("drop table if exists test.foo")
        expected = [1, 3, 5]
        data = {1: [8, 2, 58], 2: expected, 5: [23, 18, 3]}
        self._create_and_populate_table(subtype="varint", data=data)
        self._execute_test(expected, self.assertEqual)

    def test_varint_vector_prepared(self):
        self.session.execute("drop table if exists test.foo")
        expected = [1, 3, 5]
        data = {1: [8, 2, 58], 2: expected, 5: [23, 18, 3]}
        self._create_and_populate_table_prepared(subtype="varint", data=data)
        self._execute_test(expected, self.assertEqual)

    def test_string_vector(self):
        self.session.execute("drop table if exists test.foo")
        expected = ["foo", "bar", "baz"]
        data = {1: ["a", "b", "c"], 2: expected, 5: ["x", "y", "z"]}
        self._create_and_populate_table(subtype="text", data=data)
        self._execute_test(expected, self.assertEqual)

    def test_map_vector(self):
        self.session.execute("drop table if exists test.foo")
        expected = [{"foo": 1}, {"bar": 2}, {"baz": 3}]
        data = {1: [{"a": 1}, {"b": 2}, {"c": 3}], 2: expected, 5: [{"x": 1}, {"y": 2}, {"z": 3}]}
        self._create_table("map<text,int>")
        self._populate_table_params(data)
        self._execute_test(expected, self.assertEqual)

    def test_vector_of_vector(self):
        expected = [[1, 2], [4, 5], [7, 8]]
        data = {1: [[10, 20], [40, 50], [70, 80]], 2: expected, 5: [[100, 200], [400, 500], [700, 800]]}
        self._create_table("vector<int,2>")
        self._populate_table_params(data)
        self._execute_test(expected, self.assertEqual)
88 changes: 63 additions & 25 deletions tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,16 @@ def test_parse_casstype_vector(self):
self.assertEqual(3, ctype.vector_size)
self.assertEqual(FloatType, ctype.subtype)

def test_parse_casstype_vector_of_vectors(self):
    """Parsing a vector whose subtype is a vector yields nested VectorType classes."""
    nested = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)"
    outer = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (nested))

    # Outer layer: a vector of 3 elements
    self.assertTrue(issubclass(outer, VectorType))
    self.assertEqual(3, outer.vector_size)

    # Inner layer: each element is a vector of 4 floats
    inner = outer.subtype
    self.assertTrue(issubclass(inner, VectorType))
    self.assertEqual(4, inner.vector_size)
    self.assertEqual(FloatType, inner.subtype)

def test_empty_value(self):
self.assertEqual(str(EmptyValue()), 'EMPTY')

Expand Down Expand Up @@ -309,8 +319,20 @@ def test_cql_quote(self):
self.assertEqual(cql_quote('test'), "'test'")
self.assertEqual(cql_quote(0), '0')

def _round_trip_compare_fn(self, first, second):
    """
    Assert that a round-tripped value matches the value it was built from.

    Floats are compared to 5 decimal places. Lists, sets and dicts are checked
    for equal length and then compared element by element, recursively, so
    nested containers and non-numeric elements are handled. Anything else is
    compared with assertEqual.
    """
    if isinstance(first, float):
        self.assertAlmostEqual(first, second, places=5)
    elif isinstance(first, list):
        # zip() stops at the shorter input, so check the lengths explicitly
        self.assertEqual(len(first), len(second))
        for (felem, selem) in zip(first, second):
            self._round_trip_compare_fn(felem, selem)
    elif isinstance(first, set):
        self.assertEqual(len(first), len(second))
        # Set iteration order is not guaranteed to line up between the two
        # values, so pair elements up in sorted order instead.
        for (felem, selem) in zip(sorted(first), sorted(second)):
            self._round_trip_compare_fn(felem, selem)
    elif isinstance(first, dict):
        self.assertEqual(len(first), len(second))
        for ((fk, fv), (sk, sv)) in zip(first.items(), second.items()):
            self.assertEqual(fk, sk)
            self._round_trip_compare_fn(fv, sv)
    else:
        self.assertEqual(first, second)

def test_vector_round_trip_types_with_serialized_size(self):
# Test all the types which specify a serialized size... see PYTHON-1371 for details
self._round_trip_test([True, False, False, True], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.BooleanType, 4)")
self._round_trip_test([3.4, 2.9, 41.6, 12.0], \
Expand All @@ -325,41 +347,49 @@ def test_vector_round_trip_types_with_serialized_size(self):
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeUUIDType, 4)")
self._round_trip_test([3, 2, 41, 12], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ShortType, 4)")
self._round_trip_test([datetime.time(1,1,1), datetime.time(2,2,2), datetime.time(3,3,3)], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeType, 3)")

def test_vector_round_trip_types_without_serialized_size(self):
# Test all the types which do not specify a serialized size... see PYTHON-1371 for details
# Varints
with self.assertRaises(VectorDeserializationFailure):
self._round_trip_test([3, 2, 41, 12], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)")
self._round_trip_test([3, 2, 41, 12], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)")
# ASCII text
with self.assertRaises(VectorDeserializationFailure):
self._round_trip_test(["abc", "def", "ghi", "jkl"], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.AsciiType, 4)")
self._round_trip_test(["abc", "def", "ghi", "jkl"], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.AsciiType, 4)")
# UTF8 text
with self.assertRaises(VectorDeserializationFailure):
self._round_trip_test(["abc", "def", "ghi", "jkl"], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 4)")
self._round_trip_test(["abc", "def", "ghi", "jkl"], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 4)")
# Time is something of a weird one. By rights it should be a fixed size type but C* code marks it as variable
# size. We're forced to follow the C* code base (since that's who'll be providing the data we're parsing) so
# we match what they're doing.
self._round_trip_test([datetime.time(1,1,1), datetime.time(2,2,2), datetime.time(3,3,3)], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeType, 3)")
# Duration (contains varints)
with self.assertRaises(VectorDeserializationFailure):
self._round_trip_test([util.Duration(1,1,1), util.Duration(2,2,2), util.Duration(3,3,3)], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DurationType, 3)")
self._round_trip_test([util.Duration(1,1,1), util.Duration(2,2,2), util.Duration(3,3,3)], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DurationType, 3)")
# List (of otherwise serializable type)
with self.assertRaises(VectorDeserializationFailure):
self._round_trip_test([[3.4], [2.9], [41.6], [12.0]], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType(org.apache.cassandra.db.marshal.FloatType), 4)")
self._round_trip_test([[3.4], [2.9], [41.6], [12.0]], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType \
(org.apache.cassandra.db.marshal.FloatType), 4)")
# Set (of otherwise serializable type)
with self.assertRaises(VectorDeserializationFailure):
self._round_trip_test([set([3.4]), set([2.9]), set([41.6]), set([12.0])], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.FloatType), 4)")
self._round_trip_test([set([3.4]), set([2.9]), set([41.6]), set([12.0])], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType \
(org.apache.cassandra.db.marshal.FloatType), 4)")
# Map (of otherwise serializable types)
with self.assertRaises(VectorDeserializationFailure):
self._round_trip_test([{1:3.4}, {2:2.9}, {3:41.6}, {4:12.0}], \
self._round_trip_test([{1:3.4}, {2:2.9}, {3:41.6}, {4:12.0}], \
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \
(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.FloatType), 4)")

def test_vector_of_vectors(self):
    """Round-trip vectors whose subtype is itself a vector."""
    # NOTE: each type string below is a single literal continued across lines with
    # a trailing backslash, so the text on the continuation line is part of the string.

    # Fixed size subtypes of subtypes
    self._round_trip_test([[1.2, 3.4], [5.6, 7.8], [9.10, 11.12], [13.14, 15.16]], \
    "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType \
(org.apache.cassandra.db.marshal.FloatType,2), 4)")

    # Subtypes of subtypes without a fixed size
    self._round_trip_test([["one", "two"], ["three", "four"], ["five", "six"], ["seven", "eight"]], \
    "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.VectorType \
(org.apache.cassandra.db.marshal.AsciiType,2), 4)")

def _round_trip_test(self, data, ctype_str):
ctype = parse_casstype_args(ctype_str)
data_bytes = ctype.serialize(data, 0)
Expand All @@ -369,12 +399,20 @@ def _round_trip_test(self, data, ctype_str):
result = ctype.deserialize(data_bytes, 0)
self.assertEqual(len(data), len(result))
for idx in range(0,len(data)):
self.assertAlmostEqual(data[idx], result[idx], places=5)
self._round_trip_compare_fn(data[idx], result[idx])

# parse_casstype_args() is tested above... we're explicitly concerned about cql_parameterized_type() output here
def test_vector_cql_parameterized_type(self):
    """cql_parameterized_type() renders both plain and nested vector types."""
    float_vector = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)"
    float_vector_cql = "org.apache.cassandra.db.marshal.VectorType<float, 4>"

    # Plain vector with a simple subtype
    parsed = parse_casstype_args(float_vector)
    self.assertEqual(parsed.cql_parameterized_type(), float_vector_cql)

    # Vector whose subtype is itself a vector
    parsed_nested = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (float_vector))
    expected_nested = "org.apache.cassandra.db.marshal.VectorType<%s, 3>" % (float_vector_cql)
    self.assertEqual(parsed_nested.cql_parameterized_type(), expected_nested)

ZERO = datetime.timedelta(0)


Expand Down