diff --git a/src/decrypt.rs b/src/decrypt.rs index 1922fbe10..486d78689 100644 --- a/src/decrypt.rs +++ b/src/decrypt.rs @@ -11,11 +11,9 @@ use bytes::Bytes; use itertools::Itertools; use rayon::prelude::*; use std::io::Cursor; -use std::sync::Arc; use xor_name::XorName; -pub fn decrypt(src_hashes: Vec, encrypted_chunks: Vec) -> Result { - let src_hashes = Arc::new(src_hashes); +pub fn decrypt(src_hashes: Vec, encrypted_chunks: &[&EncryptedChunk]) -> Result { let num_chunks = encrypted_chunks.len(); let cpus = num_cpus::get(); let batch_size = usize::max(1, (num_chunks as f64 / cpus as f64).ceil() as usize); @@ -23,30 +21,21 @@ pub fn decrypt(src_hashes: Vec, encrypted_chunks: Vec) let raw_chunks: Vec<(usize, Bytes)> = encrypted_chunks .chunks(batch_size) .par_bridge() - .map(|batch| DecryptionBatch { - jobs: batch - .iter() - .map(|c| DecryptionJob { - index: c.index, - encrypted_content: c.content.clone(), - src_hashes: src_hashes.clone(), - }) - .collect_vec(), - }) .map(|batch| { - batch - .jobs + let mut decrypted_batch = Vec::with_capacity(batch.len()); + let iter = batch .par_iter() .map(|c| { - Ok::<(usize, Bytes), Error>(( - c.index, - decrypt_chunk(c.index, c.encrypted_content.clone(), c.src_hashes.as_ref())?, - )) + // we can pass &src_hashes since Rayon uses scopes under the hood which guarantees that threads are + // joined before src_hashes goes out of scope + let bytes = decrypt_chunk(c.index, &c.content, &src_hashes)?; + Ok::<(usize, Bytes), Error>((c.index, bytes)) }) - .collect::>() + .flatten(); + decrypted_batch.par_extend(iter); + decrypted_batch }) .flatten() - .flatten() .collect(); if num_chunks > raw_chunks.len() { @@ -66,19 +55,9 @@ pub fn decrypt(src_hashes: Vec, encrypted_chunks: Vec) Ok(raw_data) } -struct DecryptionBatch { - jobs: Vec, -} - -struct DecryptionJob { - index: usize, - encrypted_content: Bytes, - src_hashes: Arc>, -} - pub(crate) fn decrypt_chunk( chunk_number: usize, - content: Bytes, + content: &Bytes, chunk_hashes: &[XorName], ) -> Result { let (pad, key, iv) = get_pad_key_and_iv(chunk_number, chunk_hashes); diff --git a/src/encrypt.rs b/src/encrypt.rs index 82112586d..9f152c488 100644 --- a/src/encrypt.rs +++ b/src/encrypt.rs @@ -102,5 +102,5 @@ pub(crate) fn encrypt_chunk(content: Bytes, pki: (Pad, Key, Iv)) -> Result Result { if encrypted_chunk.index == self.chunk_index { let decrypted_content = - decrypt_chunk(self.chunk_index, encrypted_chunk.content, &self.src_hashes)?; + decrypt_chunk(self.chunk_index, &encrypted_chunk.content, &self.src_hashes)?; self.append_to_file(&decrypted_content)?; self.chunk_index += 1; @@ -344,7 +344,7 @@ impl StreamSelfDecryptor { let _ = chunk_file.read_to_end(&mut chunk_data)?; let decrypted_content = - decrypt_chunk(self.chunk_index, chunk_data.into(), &self.src_hashes)?; + decrypt_chunk(self.chunk_index, &chunk_data.into(), &self.src_hashes)?; self.append_to_file(&decrypted_content)?; self.chunk_index += 1; @@ -426,12 +426,9 @@ pub fn encrypt(bytes: Bytes) -> Result<(DataMap, Vec)> { /// Decrypts what is expected to be the full set of chunks covered by the data map. pub fn decrypt_full_set(data_map: &DataMap, chunks: &[EncryptedChunk]) -> Result { let src_hashes = extract_hashes(data_map); - let sorted_chunks = chunks - .iter() - .sorted_by_key(|c| c.index) - .cloned() // should not be needed, something is wrong here, the docs for sorted_by_key says it will return owned items...! - .collect_vec(); - decrypt::decrypt(src_hashes, sorted_chunks) + let mut sorted_chunks = Vec::with_capacity(chunks.len()); + sorted_chunks.extend(chunks.iter().sorted_by_key(|c| c.index)); + decrypt::decrypt(src_hashes, &sorted_chunks) } /// Decrypts a range, used when seeking. @@ -444,12 +441,10 @@ pub fn decrypt_range( len: usize, ) -> Result { let src_hashes = extract_hashes(data_map); - let encrypted_chunks = chunks - .iter() - .sorted_by_key(|c| c.index) - .cloned() - .collect_vec(); - let mut bytes = decrypt::decrypt(src_hashes, encrypted_chunks)?; + let mut sorted_chunks = Vec::with_capacity(chunks.len()); + sorted_chunks.extend(chunks.iter().sorted_by_key(|c| c.index)); + + let mut bytes = decrypt::decrypt(src_hashes, &sorted_chunks)?; if relative_pos >= bytes.len() { return Ok(Bytes::new()); @@ -463,7 +458,7 @@ pub fn decrypt_range( } /// Helper function to XOR a data with a pad (pad will be rotated to fill the length) -pub(crate) fn xor(data: Bytes, &Pad(pad): &Pad) -> Bytes { +pub(crate) fn xor(data: &Bytes, &Pad(pad): &Pad) -> Bytes { let vec: Vec<_> = data .iter() .zip(pad.iter().cycle())