-
Notifications
You must be signed in to change notification settings - Fork 543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PYTHON-1369 Extend driver vector support to arbitrary subtypes and fix handling of variable length types (OSS C* 5.0) #1217
Changes from 11 commits
d153c61
69f54b0
72a27ac
fe7a3b5
c295a8c
d0d5983
3534876
a0791b0
f45d7df
daa54f1
07c86bb
3363d16
1cba1b9
44966c9
dcb008f
8442743
f35dcda
4f6bef8
e23e425
e101213
1840a6d
1f7dd90
c053b67
16cef42
f996a7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,8 +48,8 @@ | |
int32_pack, int32_unpack, int64_pack, int64_unpack, | ||
float_pack, float_unpack, double_pack, double_unpack, | ||
varint_pack, varint_unpack, point_be, point_le, | ||
vints_pack, vints_unpack) | ||
from cassandra import util, VectorDeserializationFailure | ||
vints_pack, vints_unpack, uvint_unpack, uvint_pack) | ||
from cassandra import util | ||
|
||
_little_endian_flag = 1 # we always serialize LE | ||
import ipaddress | ||
|
@@ -393,6 +393,9 @@ def cass_parameterized_type(cls, full=False): | |
""" | ||
return cls.cass_parameterized_type_with(cls.subtypes, full=full) | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
return None | ||
|
||
# it's initially named with a _ to avoid registering it as a real type, but | ||
# client programs may want to use the name still for isinstance(), etc | ||
|
@@ -461,7 +464,6 @@ def serialize(uuid, protocol_version): | |
|
||
class BooleanType(_CassandraType): | ||
typename = 'boolean' | ||
serial_size = 1 | ||
|
||
@staticmethod | ||
def deserialize(byts, protocol_version): | ||
|
@@ -471,6 +473,10 @@ def deserialize(byts, protocol_version): | |
def serialize(truth, protocol_version): | ||
return int8_pack(truth) | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
return 1 | ||
|
||
class ByteType(_CassandraType): | ||
typename = 'tinyint' | ||
|
||
|
@@ -501,7 +507,6 @@ def serialize(var, protocol_version): | |
|
||
class FloatType(_CassandraType): | ||
typename = 'float' | ||
serial_size = 4 | ||
|
||
@staticmethod | ||
def deserialize(byts, protocol_version): | ||
|
@@ -511,10 +516,12 @@ def deserialize(byts, protocol_version): | |
def serialize(byts, protocol_version): | ||
return float_pack(byts) | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
return 4 | ||
|
||
class DoubleType(_CassandraType): | ||
typename = 'double' | ||
serial_size = 8 | ||
|
||
@staticmethod | ||
def deserialize(byts, protocol_version): | ||
|
@@ -524,10 +531,12 @@ def deserialize(byts, protocol_version): | |
def serialize(byts, protocol_version): | ||
return double_pack(byts) | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
return 8 | ||
|
||
class LongType(_CassandraType): | ||
typename = 'bigint' | ||
serial_size = 8 | ||
|
||
@staticmethod | ||
def deserialize(byts, protocol_version): | ||
|
@@ -537,10 +546,12 @@ def deserialize(byts, protocol_version): | |
def serialize(byts, protocol_version): | ||
return int64_pack(byts) | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
return 8 | ||
|
||
class Int32Type(_CassandraType): | ||
typename = 'int' | ||
serial_size = 4 | ||
|
||
@staticmethod | ||
def deserialize(byts, protocol_version): | ||
|
@@ -550,6 +561,9 @@ def deserialize(byts, protocol_version): | |
def serialize(byts, protocol_version): | ||
return int32_pack(byts) | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
return 4 | ||
|
||
class IntegerType(_CassandraType): | ||
typename = 'varint' | ||
|
@@ -653,7 +667,6 @@ class TimestampType(DateType): | |
|
||
class TimeUUIDType(DateType): | ||
typename = 'timeuuid' | ||
serial_size = 16 | ||
|
||
def my_timestamp(self): | ||
return util.unix_time_from_uuid1(self.val) | ||
|
@@ -669,6 +682,9 @@ def serialize(timeuuid, protocol_version): | |
except AttributeError: | ||
raise TypeError("Got a non-UUID object for a UUID value") | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
return 16 | ||
|
||
class SimpleDateType(_CassandraType): | ||
typename = 'date' | ||
|
@@ -700,7 +716,6 @@ def serialize(val, protocol_version): | |
|
||
class ShortType(_CassandraType): | ||
typename = 'smallint' | ||
serial_size = 2 | ||
|
||
@staticmethod | ||
def deserialize(byts, protocol_version): | ||
|
@@ -710,10 +725,18 @@ def deserialize(byts, protocol_version): | |
def serialize(byts, protocol_version): | ||
return int16_pack(byts) | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
return 2 | ||
|
||
class TimeType(_CassandraType): | ||
typename = 'time' | ||
serial_size = 8 | ||
# Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as | ||
# variable size... and we have to match what the server expects since the server | ||
# uses that specification to encode data of that type. | ||
#@classmethod | ||
#def serial_size(cls): | ||
# return 8 | ||
|
||
@staticmethod | ||
def deserialize(byts, protocol_version): | ||
|
@@ -1410,6 +1433,11 @@ class VectorType(_CassandraType): | |
vector_size = 0 | ||
subtype = None | ||
|
||
@classmethod | ||
def serial_size(cls): | ||
serialized_size = cls.subtype.serial_size() | ||
return cls.vector_size * serialized_size if serialized_size is not None else None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here (as in many other points in this PR) we're following the example of what's done on the server side. |
||
|
||
@classmethod | ||
def apply_parameters(cls, params, names): | ||
assert len(params) == 2 | ||
|
@@ -1419,19 +1447,31 @@ def apply_parameters(cls, params, names): | |
|
||
@classmethod | ||
def deserialize(cls, byts, protocol_version): | ||
serialized_size = getattr(cls.subtype, "serial_size", None) | ||
if not serialized_size: | ||
raise VectorDeserializationFailure("Cannot determine serialized size for vector with subtype %s" % cls.subtype.__name__) | ||
indexes = (serialized_size * x for x in range(0, cls.vector_size)) | ||
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes] | ||
serialized_size = cls.subtype.serial_size() | ||
if serialized_size is not None: | ||
indexes = (serialized_size * x for x in range(0, cls.vector_size)) | ||
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes] | ||
|
||
idx = 0 | ||
rv = [] | ||
while (idx < len(byts)): | ||
size, bytes_read = uvint_unpack(byts[idx:]) | ||
idx += bytes_read | ||
rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version)) | ||
idx += size | ||
return rv | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to throw an error when the length of elements in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is actually happening (at least for deserialize) with the changes above. I've also recently added support for similar tests in serialize() (along with tests for all the cases). |
||
|
||
@classmethod | ||
def serialize(cls, v, protocol_version): | ||
buf = io.BytesIO() | ||
serialized_size = cls.subtype.serial_size() | ||
for item in v: | ||
buf.write(cls.subtype.serialize(item, protocol_version)) | ||
item_bytes = cls.subtype.serialize(item, protocol_version) | ||
if serialized_size is None: | ||
buf.write(uvint_pack(len(item_bytes))) | ||
buf.write(item_bytes) | ||
return buf.getvalue() | ||
|
||
@classmethod | ||
def cql_parameterized_type(cls): | ||
return "%s<%s, %s>" % (cls.typename, cls.subtype.typename, cls.vector_size) | ||
return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,7 +111,6 @@ def vints_unpack(term): # noqa | |
|
||
return tuple(values) | ||
|
||
|
||
def vints_pack(values): | ||
revbytes = bytearray() | ||
values = [int(v) for v in values[::-1]] | ||
|
@@ -143,3 +142,48 @@ def vints_pack(values): | |
|
||
revbytes.reverse() | ||
return bytes(revbytes) | ||
|
||
def uvint_unpack(bytes):
    """
    Decode one unsigned variable-length integer from the front of ``bytes``.

    The count of leading one bits in the first byte gives the number of
    additional bytes that follow; the remaining low bits of the first byte
    are the most significant bits of the value, and the following bytes
    are appended in big-endian order.

    ``bytes`` must be indexable and yield integers (e.g. ``bytes``,
    ``bytearray``).  Returns a ``(value, bytes_consumed)`` tuple.  Indexing
    past the end of the input (empty or truncated data) propagates the
    underlying ``IndexError``.
    """
    lead = bytes[0]

    # High bit clear: the value fits entirely in the first byte.
    if not lead & 128:
        return (lead, 1)

    # Inverting the byte turns the run of leading ones into leading zeros,
    # so the bit length of the inverse tells us where that run ended.
    extra = 8 - (0xff & ~lead).bit_length()
    value = lead & (0xff >> extra)

    consumed = 1
    while consumed <= extra:
        value = (value << 8) | (bytes[consumed] & 0xff)
        consumed += 1

    return (value, consumed)
|
||
def uvint_pack(val):
    """
    Encode a non-negative integer as an unsigned variable-length integer.

    Values below 128 are written as a single byte.  Larger values use a
    first byte whose leading one bits count the extra bytes that follow
    (1 to 8); the value itself is stored big-endian in the remaining low
    bits of the first byte plus the extra bytes.

    :param val: integer in the range ``0 <= val < 2**64``
    :return: the encoded value as ``bytes``
    :raises ValueError: if ``val`` is negative or needs more than 64 bits
    """
    if val < 0:
        # Fail with an explicit message instead of the unrelated
        # "byte must be in range(0, 256)" raised by bytearray.
        raise ValueError('Value %d is negative and cannot be encoded as an unsigned vint' % val)

    if val < 128:
        return bytes(bytearray((val,)))

    num_bits = val.bit_length()
    if num_bits > 64:
        raise ValueError('Value %d is too big and cannot be encoded as vint' % val)

    # With k extra bytes the first byte reserves k + 1 bits (k ones and a
    # terminating zero), leaving 7 * (k + 1) payload bits in total.  With
    # 8 extra bytes the first byte is all ones and carries no payload,
    # hence the cap at 8.
    num_extra_bytes = min((num_bits - 1) // 7, 8)

    encoded = bytearray(num_extra_bytes + 1)
    remaining = val
    # Fill the trailing bytes from least to most significant.
    for pos in range(num_extra_bytes, 0, -1):
        encoded[pos] = remaining & 0xff
        remaining >>= 8

    # Whatever is left fits below the length prefix in the first byte.
    prefix = (0xff << (8 - num_extra_bytes)) & 0xff
    encoded[0] = prefix | remaining
    return bytes(encoded)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adapted from vints_pack and vints_unpack above. Those functions are different enough (built-in zig-zag encoding + tuples for incoming/outgoing data) that it seemed worthwhile to live with the code duplication here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in PYTHON-1371 to indicate an attempt to decode a vector of an unsupported subtype. With the other changes in this PR this exception is now completely unnecessary.