Skip to content

Commit

Permalink
Support sharding index_location
Browse files Browse the repository at this point in the history
  • Loading branch information
LDeakin committed Nov 4, 2023
1 parent cc200d2 commit 8b29c75
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removes `TransposeOrderImpl`
- `TransposeOrder` is now validated on creation/deserialisation and `TransposeCodec::new` no longer returns a `Result`
- **Breaking**: Change `HexString::as_bytes()` to `as_be_bytes()`
- Support `index_location` for sharding codec

### Fixed
- Bytes codec handling of complex and raw bits data types
Expand Down
57 changes: 57 additions & 0 deletions src/array/codec/array_to_bytes/sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,27 @@ mod tests {
]
}"#;

// Sharding codec configuration with `[2, 2]` inner chunks and little-endian `bytes` codecs for
// both the inner chunks and the shard index. It sets `"index_location": "start"`, so the encoded
// shard index is expected before the inner chunks rather than at the default end location.
const JSON_VALID3: &str = r#"{
"chunk_shape": [2, 2],
"codecs": [
{
"name": "bytes",
"configuration": {
"endian": "little"
}
}
],
"index_codecs": [
{
"name": "bytes",
"configuration": {
"endian": "little"
}
}
],
"index_location": "start"
}"#;

fn codec_sharding_round_trip_impl(json: &str, chunk_shape: Vec<u64>) {
let array_representation =
ArrayRepresentation::new(chunk_shape, DataType::UInt16, FillValue::from(0u16)).unwrap();
Expand Down Expand Up @@ -187,6 +208,12 @@ mod tests {
codec_sharding_round_trip_impl(JSON_VALID2, chunk_shape);
}

#[test]
fn codec_sharding_round_trip3() {
    // Round trip using the configuration that places the shard index at the start
    // (`JSON_VALID3`), with a `[4, 4]` shape passed to the shared round-trip helper.
    codec_sharding_round_trip_impl(JSON_VALID3, vec![4, 4]);
}

#[test]
fn codec_sharding_fill_value() {
let chunk_shape = vec![4, 4];
Expand Down Expand Up @@ -281,4 +308,34 @@ mod tests {
let answer: Vec<u16> = vec![16, 17, 18, 20, 21, 22];
assert_eq!(answer, decoded_partial_chunk);
}

#[test]
fn codec_sharding_partial_decode3() {
    // A `[4, 4]` array of `u8` holding 0, 1, 2, ... so that each element equals its linear index.
    let array_representation =
        ArrayRepresentation::new(vec![4, 4], DataType::UInt8, FillValue::from(0u8)).unwrap();
    let bytes: Vec<u8> = (0..array_representation.num_elements() as u8).collect();

    // Encode with the configuration that stores the shard index at the start of the shard.
    let configuration: ShardingCodecConfiguration = serde_json::from_str(JSON_VALID3).unwrap();
    let codec = ShardingCodec::new_with_configuration(&configuration).unwrap();
    let encoded = codec.encode(bytes, &array_representation).unwrap();

    // Partially decode the region starting at `[1, 0]` with shape `[2, 1]`.
    let decoded_regions = [ArraySubset::new_with_start_shape(vec![1, 0], vec![2, 1]).unwrap()];
    let partial_decoder = codec.partial_decoder(Box::new(std::io::Cursor::new(encoded)));

    // Elements are single bytes, so the flattened region bytes are already the element values.
    let decoded_partial_chunk: Vec<u8> = partial_decoder
        .partial_decode(&array_representation, &decoded_regions)
        .unwrap()
        .into_iter()
        .flatten()
        .collect();

    let answer: Vec<u8> = vec![4, 8];
    assert_eq!(answer, decoded_partial_chunk);
}
}
94 changes: 73 additions & 21 deletions src/array/codec/array_to_bytes/sharding/sharding_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use crate::{

use super::{
calculate_chunks_per_shard, compute_index_encoded_size, decode_shard_index,
sharding_index_decoded_representation, sharding_partial_decoder, ShardingCodecConfiguration,
ShardingCodecConfigurationV1,
sharding_configuration::ShardingIndexLocation, sharding_index_decoded_representation,
sharding_partial_decoder, ShardingCodecConfiguration, ShardingCodecConfigurationV1,
};

use rayon::prelude::*;
Expand Down Expand Up @@ -45,16 +45,24 @@ pub struct ShardingCodec {
inner_codecs: CodecChain,
/// The codecs used to encode and decode the shard index.
index_codecs: CodecChain,
/// Specifies whether the shard index is located at the beginning or end of the file.
index_location: ShardingIndexLocation,
}

impl ShardingCodec {
/// Create a new `sharding` codec.
#[must_use]
pub fn new(chunk_shape: Vec<u64>, inner_codecs: CodecChain, index_codecs: CodecChain) -> Self {
pub fn new(
chunk_shape: Vec<u64>,
inner_codecs: CodecChain,
index_codecs: CodecChain,
index_location: ShardingIndexLocation,
) -> Self {
Self {
chunk_shape,
inner_codecs,
index_codecs,
index_location,
}
}

Expand All @@ -73,6 +81,7 @@ impl ShardingCodec {
configuration.chunk_shape.clone(),
inner_codecs,
index_codecs,
configuration.index_location,
))
}
}
Expand All @@ -83,6 +92,7 @@ impl CodecTraits for ShardingCodec {
chunk_shape: self.chunk_shape.clone(),
codecs: self.inner_codecs.create_metadatas(),
index_codecs: self.index_codecs.create_metadatas(),
index_location: self.index_location,
};
Some(Metadata::new_with_serializable_configuration(IDENTIFIER, &configuration).unwrap())
}
Expand Down Expand Up @@ -126,7 +136,12 @@ impl ArrayCodecTraits for ShardingCodec {

// Iterate over chunk indices
let mut shard_inner_chunks = Vec::new();
let mut encoded_shard_offset: usize = 0;
let index_encoded_size =
compute_index_encoded_size(&self.index_codecs, &index_decoded_representation)?;
let mut encoded_shard_offset = match self.index_location {
ShardingIndexLocation::Start => index_encoded_size,
ShardingIndexLocation::End => 0,
};
for (chunk_index, (_chunk_indices, chunk_subset)) in unsafe {
ArraySubset::new_with_shape(shard_representation.shape().to_vec())
.iter_chunks_unchecked(&self.chunk_shape)
Expand All @@ -146,9 +161,9 @@ impl ArrayCodecTraits for ShardingCodec {
let chunk_encoded = self.inner_codecs.encode(bytes, &chunk_representation)?;

// Append chunk, update array index and offset
shard_index[chunk_index * 2] = encoded_shard_offset.try_into().unwrap();
shard_index[chunk_index * 2] = encoded_shard_offset;
shard_index[chunk_index * 2 + 1] = chunk_encoded.len().try_into().unwrap();
encoded_shard_offset += chunk_encoded.len();
encoded_shard_offset += chunk_encoded.len() as u64;
shard_inner_chunks.push(chunk_encoded);
}
}
Expand All @@ -163,10 +178,20 @@ impl ArrayCodecTraits for ShardingCodec {
let shard_size =
shard_inner_chunks.iter().map(Vec::len).sum::<usize>() + encoded_array_index.len();
let mut shard = Vec::with_capacity(shard_size);
for chunk in shard_inner_chunks {
shard.extend(chunk);
match self.index_location {
ShardingIndexLocation::Start => {
shard.extend(encoded_array_index);
for chunk in shard_inner_chunks {
shard.extend(chunk);
}
}
ShardingIndexLocation::End => {
for chunk in shard_inner_chunks {
shard.extend(chunk);
}
shard.extend(encoded_array_index);
}
}
shard.extend(encoded_array_index);
Ok(shard)
}

Expand Down Expand Up @@ -238,13 +263,18 @@ impl ArrayCodecTraits for ShardingCodec {

// Write the shard index
let index_decoded_representation = sharding_index_decoded_representation(&chunks_per_shard);
let index_encoded_size =
compute_index_encoded_size(&self.index_codecs, &index_decoded_representation)?;
let mut shard_index = vec![u64::MAX; index_decoded_representation.num_elements_usize()];
let mut offset = 0;
let mut encoded_shard_offset = match self.index_location {
ShardingIndexLocation::Start => index_encoded_size,
ShardingIndexLocation::End => 0,
};
for (chunk_index, chunk) in &shard_inner_chunks {
if let Some(chunk) = chunk {
shard_index[chunk_index * 2] = offset.try_into().unwrap();
shard_index[chunk_index * 2] = encoded_shard_offset;
shard_index[chunk_index * 2 + 1] = chunk.len().try_into().unwrap();
offset += chunk.len();
encoded_shard_offset += chunk.len() as u64;
}
}

Expand All @@ -261,13 +291,26 @@ impl ArrayCodecTraits for ShardingCodec {
.sum::<usize>()
+ encoded_array_index.len();
let mut shard = Vec::with_capacity(shard_size);
for chunk in shard_inner_chunks
.into_iter()
.filter_map(|(_, chunk)| chunk)
{
shard.extend(chunk);
match self.index_location {
ShardingIndexLocation::Start => {
shard.extend(encoded_array_index);
for chunk in shard_inner_chunks
.into_iter()
.filter_map(|(_, chunk)| chunk)
{
shard.extend(chunk);
}
}
ShardingIndexLocation::End => {
for chunk in shard_inner_chunks
.into_iter()
.filter_map(|(_, chunk)| chunk)
{
shard.extend(chunk);
}
shard.extend(encoded_array_index);
}
}
shard.extend(encoded_array_index);
Ok(shard)
}

Expand Down Expand Up @@ -309,6 +352,7 @@ impl ArrayToBytesCodecTraits for ShardingCodec {
self.chunk_shape.clone(),
&self.inner_codecs,
&self.index_codecs,
self.index_location,
))
}

Expand Down Expand Up @@ -338,9 +382,17 @@ impl ShardingCodec {
"The encoded shard is smaller than the expected size of its index.".to_string(),
));
}
let encoded_shard_offset =
usize::try_from(encoded_shard.len() as u64 - index_encoded_size).unwrap();
let encoded_shard_index = &encoded_shard[encoded_shard_offset..];

let encoded_shard_index = match self.index_location {
ShardingIndexLocation::Start => {
&encoded_shard[..index_encoded_size.try_into().unwrap()]
}
ShardingIndexLocation::End => {
let encoded_shard_offset =
usize::try_from(encoded_shard.len() as u64 - index_encoded_size).unwrap();
&encoded_shard[encoded_shard_offset..]
}
};

// Decode the shard index
decode_shard_index(
Expand Down
19 changes: 17 additions & 2 deletions src/array/codec/array_to_bytes/sharding/sharding_codec_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::array::codec::{
self, ArrayToArrayCodecTraits, ArrayToBytesCodecTraits, BytesToBytesCodecTraits,
};

use super::ShardingCodec;
use super::{sharding_configuration::ShardingIndexLocation, ShardingCodec};

/// A [`ShardingCodec`] builder.
///
Expand All @@ -20,6 +20,7 @@ pub struct ShardingCodecBuilder {
array_to_array_codecs: Vec<Box<dyn ArrayToArrayCodecTraits>>,
array_to_bytes_codec: Box<dyn ArrayToBytesCodecTraits>,
bytes_to_bytes_codecs: Vec<Box<dyn BytesToBytesCodecTraits>>,
index_location: ShardingIndexLocation,
}

impl ShardingCodecBuilder {
Expand All @@ -36,6 +37,7 @@ impl ShardingCodecBuilder {
array_to_array_codecs: Vec::default(),
array_to_bytes_codec: Box::<codec::BytesCodec>::default(),
bytes_to_bytes_codecs: Vec::default(),
index_location: ShardingIndexLocation::default(),
}
}

Expand Down Expand Up @@ -94,6 +96,14 @@ impl ShardingCodecBuilder {
self
}

/// Set the index location.
///
/// If left unmodified, defaults to the end of the shard.
///
/// Returns a mutable reference to the builder.
pub fn index_location(&mut self, index_location: ShardingIndexLocation) -> &mut Self {
self.index_location = index_location;
self
}

/// Build into a [`ShardingCodec`].
#[must_use]
pub fn build(&self) -> ShardingCodec {
Expand All @@ -107,6 +117,11 @@ impl ShardingCodecBuilder {
self.index_array_to_bytes_codec.clone(),
self.index_bytes_to_bytes_codecs.clone(),
);
ShardingCodec::new(self.inner_chunk_shape.clone(), inner_codecs, index_codecs)
ShardingCodec::new(
self.inner_chunk_shape.clone(),
inner_codecs,
index_codecs,
self.index_location,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ pub struct ShardingCodecConfigurationV1 {
pub codecs: Vec<Metadata>,
/// A list of codecs to be used for encoding and decoding the shard index.
pub index_codecs: Vec<Metadata>,
/// Specifies whether the shard index is located at the beginning or end of the file.
#[serde(default)]
pub index_location: ShardingIndexLocation,
}

/// The location of the shard index within an encoded shard.
///
/// Serialised in lowercase, i.e. as `"start"` or `"end"`.
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug, Display)]
#[serde(rename_all = "lowercase")]
pub enum ShardingIndexLocation {
/// The shard index is at the start of the shard, before the encoded inner chunks.
Start,
/// The shard index is at the end of the shard, after the encoded inner chunks.
End,
}

impl Default for ShardingIndexLocation {
/// Returns [`ShardingIndexLocation::End`].
///
/// This is the value used when `index_location` is omitted from a sharding codec
/// configuration (the field is marked `#[serde(default)]`).
fn default() -> Self {
Self::End
}
}

#[cfg(test)]
Expand Down Expand Up @@ -84,6 +100,8 @@ mod tests {
}
]
}"#;
serde_json::from_str::<ShardingCodecConfiguration>(JSON).unwrap();
let config = serde_json::from_str::<ShardingCodecConfiguration>(JSON).unwrap();
let ShardingCodecConfiguration::V1(config) = config;
assert_eq!(config.index_location, ShardingIndexLocation::End);
}
}
Loading

0 comments on commit 8b29c75

Please sign in to comment.