Skip to content

Commit

Permalink
one hot encode to include 'N'
Browse files Browse the repository at this point in the history
bshifaw committed Dec 12, 2024
1 parent 23d8f25 commit cfd7da2
Showing 2 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/hidive/src/train.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ use std::path::PathBuf;
use std::iter::Chain;
use std::collections::hash_map::Keys;

use skydive::nn_model::{KmerData, KmerDataVec, KmerNN, train_model, evaluate_model, prepare_tensors, split_data};
use skydive::nn_model::{KmerData, KmerDataVec, KmerNN, train_model, evaluate_model, prepare_tensors, split_data, one_hot_encode_from_dna_ascii};
use candle_nn::{Module, Optimizer, VarBuilder, VarMap};
use candle_core::{DType, Device};
const DEVICE: Device = Device::Cpu;
25 changes: 14 additions & 11 deletions src/skydive/src/nn_model.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use candle_core::{DType, Device, Tensor};
use candle_nn::{linear, Linear, Module, Optimizer, VarBuilder, LayerNorm};
use candle_nn::{linear, Linear, Module, Optimizer, VarBuilder, LayerNorm, ops::sigmoid};
use rand::seq::SliceRandom;
use rand::thread_rng;

@@ -54,6 +54,7 @@ impl Module for KmerNN {
let x = self.ln2.forward(&x)?;
let x = x.relu()?;
let x = self.fc3.forward(&x)?;
// let x = sigmoid(&x)?;
Ok(x)
}
}
@@ -154,13 +155,14 @@ pub fn split_data(data: &[KmerData]) -> (Vec<KmerData>, Vec<KmerData>) {
/// One-hot encode the sequence data.
/// The function returns a vector containing the one-hot encoding of the DNA sequence.
pub fn one_hot_encode_from_dna_string(seq: &str) -> Vec<f32> {
let mut one_hot = vec![0.0; 4 * seq.len()];
let mut one_hot = vec![0.0; 5 * seq.len()];
for (i, c) in seq.chars().enumerate() {
match c {
'A' => one_hot[i * 4] = 1.0,
'C' => one_hot[i * 4 + 1] = 1.0,
'G' => one_hot[i * 4 + 2] = 1.0,
'T' => one_hot[i * 4 + 3] = 1.0,
'A' => one_hot[i * 5] = 1.0,
'C' => one_hot[i * 5 + 1] = 1.0,
'G' => one_hot[i * 5 + 2] = 1.0,
'T' => one_hot[i * 5 + 3] = 1.0,
'N' => one_hot[i * 5 + 4] = 1.0,
_ => {}
}
}
@@ -169,13 +171,14 @@ pub fn one_hot_encode_from_dna_string(seq: &str) -> Vec<f32> {

/// One-hot encode the sequence data from a byte slice (`&[u8]`) of ASCII bases.
pub fn one_hot_encode_from_dna_ascii(seq: &[u8]) -> Vec<f32> {
let mut one_hot = vec![0.0; 4 * seq.len()];
let mut one_hot = vec![0.0; 5 * seq.len()];
for (i, c) in seq.iter().enumerate() {
match c {
65 => one_hot[i * 4] = 1.0,
67 => one_hot[i * 4 + 1] = 1.0,
84 => one_hot[i * 4 + 2] = 1.0,
71 => one_hot[i * 4 + 3] = 1.0,
65 => one_hot[i * 5] = 1.0,
67 => one_hot[i * 5 + 1] = 1.0,
84 => one_hot[i * 5 + 2] = 1.0,
71 => one_hot[i * 5 + 3] = 1.0,
78 => one_hot[i * 5 + 4] = 1.0,
_ => {}
}
}

0 comments on commit cfd7da2

Please sign in to comment.