Skip to content

Commit

Permalink
Use encode/decode/partial_decode opt methods
Browse files Browse the repository at this point in the history
Also avoid a copy in sharding partial decoder decode_shard_index
  • Loading branch information
LDeakin committed Nov 12, 2023
1 parent e3fbf60 commit 77e5a24
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 131 deletions.
20 changes: 11 additions & 9 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,9 @@ impl<TStorage: ?Sized + ReadableStorageTraits> Array<TStorage> {
}

/// Initialises a partial decoder for the chunk at `chunk_indices` with optional parallelism.
#[doc(hidden)]
///
/// # Errors
/// Returns an [`ArrayError`] if initialisation of the partial decoder fails.
pub fn partial_decoder_opt<'a>(
&'a self,
chunk_indices: &[u64],
Expand Down Expand Up @@ -1146,14 +1148,14 @@ impl<TStorage: ?Sized + WritableStorageTraits> Array<TStorage> {
let storage_transformer = self
.storage_transformers()
.create_writable_transformer(storage_handle);
let chunk_encoded: Vec<u8> = if self.parallel_codecs() {
self.codecs()
.par_encode(chunk_bytes, &chunk_array_representation)
} else {
self.codecs()
.encode(chunk_bytes, &chunk_array_representation)
}
.map_err(ArrayError::CodecError)?;
let chunk_encoded: Vec<u8> = self
.codecs()
.encode_opt(
chunk_bytes,
&chunk_array_representation,
self.parallel_codecs(),
)
.map_err(ArrayError::CodecError)?;
crate::storage::store_chunk(
&*storage_transformer,
self.path(),
Expand Down
45 changes: 12 additions & 33 deletions src/array/codec/array_to_bytes/codec_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,32 +323,21 @@ impl ArrayCodecTraits for CodecChain {
let mut value = decoded_value;
// array->array
for codec in &self.array_to_array {
value = if parallel {
codec.par_encode(value, &decoded_representation)
} else {
codec.encode(value, &decoded_representation)
}?;
value = codec.encode_opt(value, &decoded_representation, parallel)?;
decoded_representation = codec.compute_encoded_size(&decoded_representation)?;
}

// array->bytes
value = if parallel {
self.array_to_bytes
.par_encode(value, &decoded_representation)
} else {
self.array_to_bytes.encode(value, &decoded_representation)
}?;
value = self
.array_to_bytes
.encode_opt(value, &decoded_representation, parallel)?;
let mut decoded_representation = self
.array_to_bytes
.compute_encoded_size(&decoded_representation)?;

// bytes->bytes
for codec in &self.bytes_to_bytes {
value = if parallel {
codec.par_encode(value)
} else {
codec.encode(value)
}?;
value = codec.encode_opt(value, parallel)?;
decoded_representation = codec.compute_encoded_size(&decoded_representation);
}

Expand All @@ -371,32 +360,22 @@ impl ArrayCodecTraits for CodecChain {
self.bytes_to_bytes.iter().rev(),
bytes_representations.iter().rev().skip(1),
) {
encoded_value = if parallel {
codec.par_decode(encoded_value, bytes_representation)
} else {
codec.decode(encoded_value, bytes_representation)
}?;
encoded_value = codec.decode_opt(encoded_value, bytes_representation, parallel)?;
}

// bytes->array
encoded_value = if parallel {
self.array_to_bytes
.par_decode(encoded_value, array_representations.last().unwrap())
} else {
self.array_to_bytes
.decode(encoded_value, array_representations.last().unwrap())
}?;
encoded_value = self.array_to_bytes.decode_opt(
encoded_value,
array_representations.last().unwrap(),
parallel,
)?;

// array->array
for (codec, array_representation) in std::iter::zip(
self.array_to_array.iter().rev(),
array_representations.iter().rev().skip(1),
) {
encoded_value = if parallel {
codec.par_decode(encoded_value, array_representation)
} else {
codec.decode(encoded_value, array_representation)
}?;
encoded_value = codec.decode_opt(encoded_value, array_representation, parallel)?;
}

if encoded_value.len() as u64 != decoded_representation.size() {
Expand Down
9 changes: 3 additions & 6 deletions src/array/codec/array_to_bytes/sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,14 @@ fn compute_index_encoded_size(
}

fn decode_shard_index(
encoded_shard_index: &[u8],
encoded_shard_index: Vec<u8>,
index_array_representation: &ArrayRepresentation,
index_codecs: &dyn ArrayToBytesCodecTraits,
parallel: bool,
) -> Result<Vec<u64>, CodecError> {
// Decode the shard index
let decoded_shard_index = if parallel {
index_codecs.par_decode(encoded_shard_index.to_vec(), index_array_representation)
} else {
index_codecs.decode(encoded_shard_index.to_vec(), index_array_representation)
}?;
let decoded_shard_index =
index_codecs.decode_opt(encoded_shard_index, index_array_representation, parallel)?;
Ok(decoded_shard_index
.chunks_exact(core::mem::size_of::<u64>())
.map(|v| u64::from_ne_bytes(v.try_into().unwrap() /* safe */))
Expand Down
145 changes: 64 additions & 81 deletions src/array/codec/array_to_bytes/sharding/sharding_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,12 @@ impl ArrayCodecTraits for ShardingCodec {
calculate_chunks_per_shard(decoded_representation.shape(), &self.chunk_shape)
.map_err(|e| CodecError::Other(e.to_string()))?;
let shard_index = self.decode_index(&encoded_value, &chunks_per_shard, false)?; // FIXME: par decode index?
let chunks = if parallel {
self.par_decode_chunks(&encoded_value, &shard_index, decoded_representation)?
} else {
self.decode_chunks(&encoded_value, &shard_index, decoded_representation)?
};
let chunks = self.decode_chunks(
&encoded_value,
&shard_index,
decoded_representation,
parallel,
)?;
Ok(chunks)
}
}
Expand Down Expand Up @@ -672,7 +673,7 @@ impl ShardingCodec {

// Decode the shard index
decode_shard_index(
encoded_shard_index,
encoded_shard_index.to_vec(),
&index_array_representation,
&self.index_codecs,
parallel,
Expand All @@ -684,6 +685,7 @@ impl ShardingCodec {
encoded_shard: &[u8],
shard_index: &[u64],
shard_representation: &ArrayRepresentation,
parallel: bool,
) -> Result<Vec<u8>, CodecError> {
// Allocate an array for the output
let mut shard = vec![MaybeUninit::<u8>::uninit(); shard_representation.size_usize()];
Expand All @@ -698,84 +700,68 @@ impl ShardingCodec {
shard_representation.fill_value().clone(),
)
};
let element_size = chunk_representation.element_size() as u64;
for (chunk_index, (_chunk_indices, chunk_subset)) in unsafe {
ArraySubset::new_with_shape(shard_representation.shape().to_vec())
.iter_chunks_unchecked(&self.chunk_shape)
}
.enumerate()
{
// Read the offset/size
let offset = shard_index[chunk_index * 2];
let size = shard_index[chunk_index * 2 + 1];
let decoded_chunk = if offset == u64::MAX && size == u64::MAX {
chunk_representation
.fill_value()
.as_ne_bytes()
.repeat(chunk_representation.num_elements_usize())
} else {
let offset: usize = offset.try_into().unwrap(); // safe
let size: usize = size.try_into().unwrap(); // safe
let encoded_chunk_slice = encoded_shard[offset..offset + size].to_vec();
self.inner_codecs
.decode(encoded_chunk_slice, &chunk_representation)?
};

// Copy to subset of shard
let mut data_idx = 0;
for (index, num_elements) in unsafe {
chunk_subset
.iter_contiguous_linearised_indices_unchecked(shard_representation.shape())
} {
let shard_offset = usize::try_from(index * element_size).unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
shard_slice[shard_offset..shard_offset + length]
.copy_from_slice(&decoded_chunk[data_idx..data_idx + length]);
data_idx += length;
}
}

#[allow(clippy::transmute_undefined_repr)]
let shard = unsafe { core::mem::transmute(shard) };
Ok(shard)
}

fn par_decode_chunks(
&self,
encoded_shard: &[u8],
shard_index: &[u64],
shard_representation: &ArrayRepresentation,
) -> Result<Vec<u8>, CodecError> {
// Allocate an array for the output
let mut shard = vec![MaybeUninit::<u8>::uninit(); shard_representation.size_usize()];
let shard_slice =
unsafe { std::slice::from_raw_parts_mut(shard.as_mut_ptr().cast::<u8>(), shard.len()) };
let shard_slice = UnsafeCellSlice::new(shard_slice);

let chunk_representation = unsafe {
ArrayRepresentation::new_unchecked(
self.chunk_shape.clone(),
shard_representation.data_type().clone(),
shard_representation.fill_value().clone(),
if parallel {
let chunks_per_shard = calculate_chunks_per_shard(
shard_representation.shape(),
chunk_representation.shape(),
)
};
let chunks_per_shard =
calculate_chunks_per_shard(shard_representation.shape(), chunk_representation.shape())
.map_err(|e| CodecError::Other(e.to_string()))?;
.map_err(|e| CodecError::Other(e.to_string()))?;
let shard_slice = UnsafeCellSlice::new(shard_slice);
(0..chunks_per_shard.iter().product::<u64>())
.into_par_iter()
.try_for_each(|chunk_index| {
let chunk_subset = self.chunk_index_to_subset(chunk_index, &chunks_per_shard);
let chunk_index = usize::try_from(chunk_index).unwrap();
let shard_slice = unsafe { shard_slice.get() };

// Read the offset/size
let offset = shard_index[chunk_index * 2];
let size = shard_index[chunk_index * 2 + 1];
let decoded_chunk = if offset == u64::MAX && size == u64::MAX {
// Can fill values be populated faster than repeat?
chunk_representation
.fill_value()
.as_ne_bytes()
.repeat(chunk_representation.num_elements_usize())
} else {
let offset: usize = offset.try_into().unwrap(); // safe
let size: usize = size.try_into().unwrap(); // safe
let encoded_chunk_slice = encoded_shard[offset..offset + size].to_vec();
// NOTE: Intentionally using single threaded decode, since parallelisation is in the loop
self.inner_codecs
.decode(encoded_chunk_slice, &chunk_representation)?
};

// Decode chunks
(0..chunks_per_shard.iter().product::<u64>())
.into_par_iter()
.try_for_each(|chunk_index| {
let chunk_subset = self.chunk_index_to_subset(chunk_index, &chunks_per_shard);
let chunk_index = usize::try_from(chunk_index).unwrap();
let shard_slice = unsafe { shard_slice.get() };
// Copy to subset of shard
let mut data_idx = 0;
let element_size = chunk_representation.element_size() as u64;
for (index, num_elements) in unsafe {
chunk_subset.iter_contiguous_linearised_indices_unchecked(
shard_representation.shape(),
)
} {
let shard_offset = usize::try_from(index * element_size).unwrap();
let length = usize::try_from(num_elements * element_size).unwrap();
shard_slice[shard_offset..shard_offset + length]
.copy_from_slice(&decoded_chunk[data_idx..data_idx + length]);
data_idx += length;
}

Ok::<_, CodecError>(())
})?;
} else {
let element_size = chunk_representation.element_size() as u64;
for (chunk_index, (_chunk_indices, chunk_subset)) in unsafe {
ArraySubset::new_with_shape(shard_representation.shape().to_vec())
.iter_chunks_unchecked(&self.chunk_shape)
}
.enumerate()
{
// Read the offset/size
let offset = shard_index[chunk_index * 2];
let size = shard_index[chunk_index * 2 + 1];
let decoded_chunk = if offset == u64::MAX && size == u64::MAX {
// Can fill values be populated faster than repeat?
chunk_representation
.fill_value()
.as_ne_bytes()
Expand All @@ -784,14 +770,12 @@ impl ShardingCodec {
let offset: usize = offset.try_into().unwrap(); // safe
let size: usize = size.try_into().unwrap(); // safe
let encoded_chunk_slice = encoded_shard[offset..offset + size].to_vec();
// NOTE: Intentionally using single threaded decode, since parallelisation is in the loop
self.inner_codecs
.decode(encoded_chunk_slice, &chunk_representation)?
};

// Copy to subset of shard
let mut data_idx = 0;
let element_size = chunk_representation.element_size() as u64;
for (index, num_elements) in unsafe {
chunk_subset
.iter_contiguous_linearised_indices_unchecked(shard_representation.shape())
Expand All @@ -802,9 +786,8 @@ impl ShardingCodec {
.copy_from_slice(&decoded_chunk[data_idx..data_idx + length]);
data_idx += length;
}

Ok::<_, CodecError>(())
})?;
}
}

#[allow(clippy::transmute_undefined_repr)]
let shard = unsafe { core::mem::transmute(shard) };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl<'a> ShardingPartialDecoder<'a> {

Ok(match encoded_shard_index {
Some(encoded_shard_index) => Some(decode_shard_index(
&encoded_shard_index,
encoded_shard_index,
&index_array_representation,
index_codecs,
parallel,
Expand Down
2 changes: 1 addition & 1 deletion src/array/codec/array_to_bytes/zfp/zfp_partial_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ArrayPartialDecoderTraits for ZfpPartialDecoder<'_> {
decoded_regions: &[ArraySubset],
parallel: bool,
) -> Result<Vec<Vec<u8>>, CodecError> {
let encoded_value = self.input_handle.decode()?;
let encoded_value = self.input_handle.decode_opt(parallel)?;
let mut out = Vec::with_capacity(decoded_regions.len());
match encoded_value {
Some(encoded_value) => {
Expand Down

0 comments on commit 77e5a24

Please sign in to comment.