Kvg improve wdl (#46)
* Simplified WDL

* Added reference sequence to disk space calculation

* Changed the correct task's cpus from 4 to 8

* Increased maxRetries for Rescue
kvg authored Dec 2, 2024
1 parent ab1a300 · commit bd0a858
Showing 7 changed files with 140 additions and 104 deletions.
2 changes: 1 addition & 1 deletion src/hidive/src/call.rs
@@ -220,7 +220,7 @@ fn phase_variants(matrix: &Vec<BTreeMap<usize, (String, u8)>>) -> (Vec<Option<St

     let wmec_matrix = WMECData::new(reads, confidences);
 
-    let (p1, p2) = skydive::wmec::phase_all(&wmec_matrix, 30, 20);
+    let (p1, p2) = skydive::wmec::phase_all(&wmec_matrix, 20, 10);
 
     let mut h1 = Vec::new();
     let mut h2 = Vec::new();
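Note on the new arguments: per the phase_all signature shown in src/skydive/src/wmec.rs below, the two trailing values are window and stride over SNP positions, so this change shrinks the phasing windows from 30 SNPs (stride 20) to 20 SNPs (stride 10). The exact tiling logic is not part of this diff; a minimal sketch of how such overlapping windows are commonly generated (make_windows is illustrative, not the crate's code):

    // Illustrative only: tile [0, num_snps) with overlapping windows.
    fn make_windows(num_snps: usize, window: usize, stride: usize) -> Vec<(usize, usize)> {
        (0..num_snps)
            .step_by(stride)
            .map(|start| (start, (start + window).min(num_snps)))
            .collect()
    }

    fn main() {
        // window = 20, stride = 10 gives 50% overlap between neighbors.
        println!("{:?}", make_windows(50, 20, 10));
    }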
52 changes: 40 additions & 12 deletions src/hidive/src/correct.rs
@@ -22,6 +22,12 @@ pub fn start(
     let cache_path = std::env::temp_dir();
     skydive::elog!("Intermediate data will be stored at {:?}.", cache_path);
 
+    let model_absolute_path = if model_path.is_absolute() {
+        model_path.clone()
+    } else {
+        std::env::current_dir().unwrap().join(model_path)
+    };
+
     // Load datasets
     let long_read_seq_urls = skydive::parse::parse_file_names(&[long_read_fasta_path.clone()]);

@@ -45,15 +51,25 @@ pub fn start(
         .for_each(|(chrom, start, end, name)| {
             let (padded_start, padded_end) = pad_interval(start, end, window);
 
-            let mut corrected_reads = HashMap::new();
+            // let mut corrected_reads_old: HashMap<String, Vec<(Vec<u8>, HashMap<Vec<u8>, f32>)>> = HashMap::new();
+            // for window_start in (padded_start..padded_end).step_by(window) {
 
-            for window_start in (padded_start..padded_end).step_by(window) {
-                let window_end = window_start + window as u64;
+            let corrected_reads: HashMap<String, Vec<(Vec<u8>, HashMap<Vec<u8>, f32>)>> = (padded_start..padded_end).step_by(window).collect::<Vec<_>>()
+                .into_par_iter()
+                .filter_map(|window_start| {
+                    let window_end = window_start + window as u64;
 
-                let locus = HashSet::from([(chrom.clone(), window_start, window_end, name.clone())]);
+                    let locus = HashSet::from([(chrom.clone(), window_start, window_end, name.clone())]);
 
-                let r = skydive::stage::stage_data_in_memory(&locus, &long_read_seq_urls, false, &cache_path);
-                if let Ok(reads) = r {
+                    let r = skydive::stage::stage_data_in_memory(&locus, &long_read_seq_urls, false, &cache_path);
+
+                    if let Ok(reads) = r {
+                        Some(reads)
+                    } else {
+                        None
+                    }
+                })
+                .map(|reads| {
                     let long_reads: HashMap<String, Vec<u8>> = reads
                         .into_iter()
                         .map(|read| (read.id().to_string(), read.seq().to_vec()))
@@ -88,22 +104,34 @@ pub fn start(
                     let s1 = LdBG::from_sequences("sr".to_string(), kmer_size, &sr_seqs);
 
                     let m = MLdBG::from_ldbgs(vec![l1, s1])
-                        .score_kmers(model_path)
+                        .score_kmers(&model_absolute_path)
                         .collapse()
                         .clean(0.1)
                         .build_links(&lr_seqs, false);
 
                     let g = m.traverse_all_kmers();
 
+                    let mut window_corrections = HashMap::new();
                     for (id, seq) in long_reads {
                         let corrected_seq = m.correct_seq(&g, &seq);
 
-                        corrected_reads.entry(id)
+                        window_corrections.entry(id)
                             .or_insert_with(Vec::new)
                             .push((corrected_seq, m.scores.clone()));
                     }
-                }
-            }
+
+                    window_corrections
+                })
+                .reduce(
+                    || HashMap::new(),
+                    |mut acc, window_map| {
+                        for (id, corrections) in window_map {
+                            acc.entry(id)
+                                .or_insert_with(Vec::new)
+                                .extend(corrections);
+                        }
+                        acc
+                    }
+                );
 
             let mut file = fa_file.lock().unwrap();
             for id in corrected_reads.keys() {
@@ -131,7 +159,7 @@ pub fn start(
             let mut qual = vec![0; joined_seq.len()];
             for i in 0..prob.len() {
                 prob[i] = prob[i].powf(1.0/count[i] as f32);
-                qual[i] = ((-10.0*(1.0 - (prob[i] - 0.0001).max(0.0)).log10()) as u8).clamp(1, 40);
+                qual[i] = ((-10.0*(1.0 - (prob[i] - 0.0001).max(0.0)).log10()) as u8 + 33).clamp(33, 73);
             }
 
             let _ = writeln!(file, "@{}\n{}\n+\n{}", id, String::from_utf8(joined_seq).unwrap(), String::from_utf8_lossy(&qual));
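Two notes on the changes above. The window loop is now a rayon pipeline: windows go through into_par_iter/filter_map/map, each producing a per-window HashMap, and reduce merges them per read id. A self-contained sketch of that merge pattern, with placeholder types standing in for the real read data:

    use rayon::prelude::*;
    use std::collections::HashMap;

    // Illustrative shape of the new flow: parallel per-window maps,
    // merged associatively in reduce(). u64 stands in for real payloads.
    fn merge_windows(windows: Vec<u64>) -> HashMap<String, Vec<u64>> {
        windows
            .into_par_iter()
            .map(|w| {
                let mut m = HashMap::new();
                m.entry(format!("read{}", w % 3)).or_insert_with(Vec::new).push(w);
                m
            })
            .reduce(HashMap::new, |mut acc, m| {
                for (k, v) in m {
                    acc.entry(k).or_insert_with(Vec::new).extend(v);
                }
                acc
            })
    }

The quality-line change is a FASTQ encoding fix: the old code wrote the raw Phred score (clamped to 1..=40) as the output byte, which is not valid Phred+33 encoding; the new code adds the 33 offset so emitted characters fall in '!' (Q=0) through 'I' (Q=40):

    // FASTQ stores Q = -10*log10(p_error) as the ASCII character Q + 33.
    fn phred_to_fastq_char(q: u8) -> char {
        (q.min(40) + 33) as char
    }

    fn main() {
        assert_eq!(phred_to_fastq_char(0), '!');  // byte 33
        assert_eq!(phred_to_fastq_char(40), 'I'); // byte 73
    }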
2 changes: 1 addition & 1 deletion src/hidive/src/rescue.rs
@@ -104,7 +104,7 @@ pub fn start(
                     found_items
                         .load(Ordering::Relaxed)
                         .to_formatted_string(&Locale::en),
-                    tid_to_chrom.get(&read.tid()).unwrap(),
+                    tid_to_chrom.get(&read.tid()).unwrap_or(&"unknown".to_string()),
                     read.reference_start().to_formatted_string(&Locale::en)
                 ));
                 progress_bar.inc(UPDATE_FREQUENCY as u64);
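A likely motivation (an assumption; the commit message does not say): reads whose target id has no entry in tid_to_chrom, for example unmapped or unplaced reads, made the old .unwrap() panic and abort the whole rescue pass, whereas unwrap_or substitutes a label:

    use std::collections::HashMap;

    // Minimal sketch: get() returns None for an unknown tid, so the
    // fallback replaces a would-be panic with an "unknown" label.
    fn main() {
        let tid_to_chrom = HashMap::from([(0i32, "chr1".to_string())]);
        let tid = -1; // e.g. an unmapped read
        let chrom = tid_to_chrom.get(&tid).cloned().unwrap_or_else(|| "unknown".to_string());
        assert_eq!(chrom, "unknown");
    }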
2 changes: 1 addition & 1 deletion src/skydive/src/ldbg.rs
@@ -995,7 +995,7 @@ impl LdBG {
             }
         }
 
-        crate::elog!("Replacing {}-{} in read of length {}\n{}\n{}", start_pos, end_pos, b.len(), q1, r1);
+        // crate::elog!("Replacing {}-{} in read of length {}\n{}\n{}", start_pos, end_pos, b.len(), q1, r1);
 
         if start_pos <= end_pos {
             b.splice(start_pos..=end_pos, replacement_path);
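For context on the surrounding code (unchanged here): the silenced log message describes the Vec::splice call visible below it, which replaces an inclusive byte range of the read with a corrected path. A minimal standalone illustration:

    // Replace positions start_pos..=end_pos of b with replacement_path.
    fn main() {
        let mut b = b"ACGTACGT".to_vec();
        let replacement_path = b"TTT".to_vec();
        let (start_pos, end_pos) = (2, 4); // bytes "GTA"
        b.splice(start_pos..=end_pos, replacement_path);
        assert_eq!(b, b"ACTTTCGT".to_vec());
    }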
10 changes: 0 additions & 10 deletions src/skydive/src/mldbg.rs
@@ -157,19 +157,11 @@ impl MLdBG {
     pub fn score_kmers(mut self, model_path: &PathBuf) -> Self {
         let gbdt = GBDT::load_model(model_path.to_str().unwrap()).unwrap();
 
-        // let lr_contigs = self.ldbgs[0].assemble_all();
-        // let lr_distances = Self::distance_to_a_contig_end(&lr_contigs, self.kmer_size);
-
-        // let sr_contigs = self.ldbgs[1].assemble_all();
-        // let sr_distances = Self::distance_to_a_contig_end(&sr_contigs, self.kmer_size);
-
         self.scores = self.union_of_kmers().iter().map(|cn_kmer| {
             let compressed_len = crate::utils::homopolymer_compressed(cn_kmer).len();
             let compressed_len_diff = (cn_kmer.len() - compressed_len) as f32;
             let entropy = crate::utils::shannon_entropy(cn_kmer);
             let gc_content = crate::utils::gc_content(cn_kmer);
-            // let lr_distance = *lr_distances.get(cn_kmer).unwrap_or(&0) as f32;
-            // let sr_distance = *sr_distances.get(cn_kmer).unwrap_or(&0) as f32;
 
             let lcov = self.ldbgs[0].kmers.get(cn_kmer).map_or(0, |record| record.coverage());
@@ -189,8 +181,6 @@ impl MLdBG {
                 compressed_len_diff, // homopolymer compression length difference
                 entropy, // shannon entropy
                 gc_content, // gc content
-                // lr_distance, // distance to nearest long read contig end
-                // sr_distance, // distance to nearest short read contig end
             ];
 
             let data = Data::new_test_data(features, None);
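The features that remain are cheap per-k-mer sequence statistics. The crate's own implementations are not shown in this diff; as a reference point, Shannon entropy over nucleotide frequencies is conventionally computed like this (a sketch, which may differ from crate::utils::shannon_entropy in detail):

    // Base-2 Shannon entropy of a k-mer's nucleotide composition.
    fn shannon_entropy(kmer: &[u8]) -> f32 {
        let mut counts = [0usize; 4];
        for b in kmer {
            match b {
                b'A' | b'a' => counts[0] += 1,
                b'C' | b'c' => counts[1] += 1,
                b'G' | b'g' => counts[2] += 1,
                b'T' | b't' => counts[3] += 1,
                _ => {}
            }
        }
        let n = kmer.len() as f32;
        counts
            .iter()
            .filter(|&&c| c > 0)
            .map(|&c| {
                let p = c as f32 / n;
                -p * p.log2()
            })
            .sum()
    }

    fn main() {
        assert!(shannon_entropy(b"AAAA").abs() < 1e-6);         // homopolymer: 0 bits
        assert!((shannon_entropy(b"ACGT") - 2.0).abs() < 1e-6); // uniform: 2 bits
    }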
77 changes: 64 additions & 13 deletions src/skydive/src/wmec.rs
@@ -6,10 +6,12 @@
 //! HAL ID: hal-01225988
 use std::collections::{BTreeSet, HashMap};
 
 use std::fs::File;
 use std::io::Write;
 
+use itertools::Itertools;
 
+use indicatif::ParallelProgressIterator;
 use rayon::prelude::*;
 
 #[derive(Debug)]
@@ -136,7 +138,7 @@ fn initialize_dp(data: &WMECData) -> (HashMap<(usize, BTreeSet<usize>, BTreeSet<
     }
 }
 
 // Function to update the DP table for each SNP position
-fn update_dp(data: &WMECData, dp: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>, backtrack: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), Option<(BTreeSet<usize>, BTreeSet<usize>)>>, snp: usize) {
+fn update_dp_old(data: &WMECData, dp: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>, backtrack: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), Option<(BTreeSet<usize>, BTreeSet<usize>)>>, snp: usize) {
     let active_fragments: BTreeSet<usize> = data.reads.iter().enumerate()
         .filter(|(_, read)| read[snp].is_some()) // Only consider fragments covering SNP
         .map(|(index, _)| index)
@@ -174,6 +176,53 @@
     }
 }
 
+fn update_dp(data: &WMECData, dp: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>, backtrack: &mut HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), Option<(BTreeSet<usize>, BTreeSet<usize>)>>, snp: usize) {
+    let active_fragments: BTreeSet<usize> = data.reads.iter().enumerate()
+        .filter(|(_, read)| read[snp].is_some())
+        .map(|(index, _)| index)
+        .collect();
+    let partitions = generate_bipartitions(&active_fragments);
+
+    // Pre-compute prev_active_fragments since it's used by all iterations
+    let prev_active_fragments: BTreeSet<usize> = data.reads.iter().enumerate()
+        .filter(|(_, read)| read[snp - 1].is_some())
+        .map(|(index, _)| index)
+        .collect();
+
+    // Collect results in parallel
+    let results: Vec<_> = partitions.par_iter()
+        .map(|(r, s)| {
+            let delta_cost = data.delta_c(snp, r, s);
+            let mut min_cost = u32::MAX;
+            let mut best_bipartition = None;
+
+            for (prev_r, prev_s) in generate_bipartitions(&prev_active_fragments) {
+                let r_compatible = r.intersection(&prev_active_fragments).all(|&x| prev_r.contains(&x));
+                let s_compatible = s.intersection(&prev_active_fragments).all(|&x| prev_s.contains(&x));
+
+                if r_compatible && s_compatible {
+                    if let Some(&prev_cost) = dp.get(&(snp - 1, prev_r.clone(), prev_s.clone())) {
+                        let current_cost = delta_cost + prev_cost;
+
+                        if current_cost < min_cost {
+                            min_cost = current_cost;
+                            best_bipartition = Some((prev_r.clone(), prev_s.clone()));
+                        }
+                    }
+                }
+            }
+
+            ((r.clone(), s.clone()), (min_cost, best_bipartition))
+        })
+        .collect();
+
+    // Update the hashmaps with the results
+    for ((r, s), (min_cost, best_bipartition)) in results {
+        dp.insert((snp, r.clone(), s.clone()), min_cost);
+        backtrack.insert((snp, r, s), best_bipartition);
+    }
+}
+
 fn backtrack_haplotypes(data: &WMECData, dp: &HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), u32>, backtrack: &HashMap<(usize, BTreeSet<usize>, BTreeSet<usize>), Option<(BTreeSet<usize>, BTreeSet<usize>)>>) -> (Vec<u8>, Vec<u8>) {
     let mut best_cost = u32::MAX;
     let mut best_bipartition = None;
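The parallel rewrite keeps update_dp_old's structure: enumerate bipartitions (R, S) of the fragments covering the current SNP, score each against every compatible bipartition at the previous SNP, and keep the cheapest predecessor. generate_bipartitions itself is not shown in this diff; a hypothetical bitmask sketch of what it plausibly computes, with the 2^n output size that makes limiting the active fragments important:

    use std::collections::BTreeSet;

    // Hypothetical: all 2^n (R, S) bipartitions of a fragment set.
    fn generate_bipartitions(fragments: &BTreeSet<usize>) -> Vec<(BTreeSet<usize>, BTreeSet<usize>)> {
        let items: Vec<usize> = fragments.iter().copied().collect();
        (0..(1u64 << items.len()))
            .map(|mask| {
                let mut r = BTreeSet::new();
                let mut s = BTreeSet::new();
                for (i, &frag) in items.iter().enumerate() {
                    if mask & (1u64 << i) != 0 { r.insert(frag); } else { s.insert(frag); }
                }
                (r, s)
            })
            .collect()
    }

    fn main() {
        let frags: BTreeSet<usize> = [0, 1, 2].into_iter().collect();
        assert_eq!(generate_bipartitions(&frags).len(), 8); // 2^3
    }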
@@ -250,28 +299,30 @@ pub fn phase_all(data: &WMECData, window: usize, stride: usize) -> (Vec<u8>, Vec
         }
     }
 
+    // crate::elog!("Windows: {:?}", windows);
+    let pb = crate::utils::default_bounded_progress_bar("Processing windows", windows.len() as u64);
+
     // Process windows in parallel
-    let haplotype_pairs: Vec<_> = windows.par_iter()
+    let haplotype_pairs: Vec<_> = windows
+        // .iter()
+        .par_iter()
+        .progress_with(pb)
         .map(|&(start, end)| {
             let (window_reads, window_confidences): (Vec<Vec<Option<u8>>>, Vec<Vec<Option<u32>>>) = data.reads.iter().zip(data.confidences.iter())
-                .filter_map(|(read, confidence)| {
+                .map(|(read, confidence)| {
                     let window_read = read[start..end].to_vec();
-                    if window_read.iter().any(|x| x.is_some()) {
-                        Some((window_read, confidence[start..end].to_vec()))
-                    } else {
-                        None
-                    }
+                    let none_count = window_read.iter().filter(|x| x.is_none()).count();
+                    (none_count, window_read, confidence[start..end].to_vec())
                 })
+                .collect::<Vec<_>>()
+                .into_iter()
+                .sorted_by_key(|(none_count, _, _)| *none_count)
+                .take(10)
+                .map(|(_, window_read, window_confidence)| (window_read, window_confidence))
                 .unzip();
 
             let window_data = WMECData::new(window_reads, window_confidences);
 
+            // crate::elog!("Processing window {} to {} ({})", start, end, window_data.num_snps);
+
            let (mut dp, mut backtrack) = initialize_dp(&window_data);
 
             for snp in 1..window_data.num_snps {
                 update_dp(&window_data, &mut dp, &mut backtrack, snp);
             }
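The other notable change in phase_all: rather than keeping every read that overlaps a window, reads are ranked by none_count (how many window positions they fail to cover) and only the 10 best-covering reads are kept. Since the DP enumerates bipartitions of the active fragments, per-SNP work grows as 2^n in the number of reads retained, so the cap bounds each window at roughly 2^10 partitions:

    // Illustrative only: bipartition counts the DP must consider per SNP.
    fn main() {
        for n in [5u32, 10, 20, 30] {
            println!("{:>2} reads -> {:>13} bipartitions", n, 2u64.pow(n));
        }
    }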
(The diff for the seventh changed file did not load.)
