Skip to content

Commit

Permalink
ICE with weight_init::which_dis
Browse files Browse the repository at this point in the history
  • Loading branch information
Thibaut-Le-Goff authored Jan 3, 2023
1 parent f5d733d commit 15a51c1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/runst.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
///////////////////// Network initialisation //////////////////////////
pub fn net_init(network_struct: &Vec<usize>, distrib: &str) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {

/*
///// list of the available functions /////
type FunType = Box<dyn Fn(usize, usize)->Vec<f32>>;
// must be of the type of the output and input of the function to call
// linking the functions(FunType) to their name(&str):
let mut functions: Vec<(FunType, &str)> = Vec::new();
functions.push((Box::new(weight_init::uniform_dis), "uniform_dis"));
functions.push((Box::new(weight_init::normal_dis), "normal_dis"));
let mut dist_list: Vec<(FunType, &str)> = Vec::new();
dist_list.push((Box::new(weight_init::uniform_dis), "uniform_dis"));
dist_list.push((Box::new(weight_init::normal_dis), "normal_dis"));
let mut function_to_call_i: usize = 0;
for i in 0..functions.len() {
if functions[i].1 == distrib {
for i in 0..dist_list.len() {
if dist_list[i].1 == distrib {
function_to_call_i = i;
}
}
*/
///////////////////////////////////////////

let (function_to_call_i, dist_list) = weight_init::which_dis(distrib);

let mut weights_tensor: Vec<Vec<f32>> = Vec::new();
let mut bias_matrix: Vec<Vec<f32>> = Vec::new();
Expand All @@ -26,10 +29,10 @@ pub fn net_init(network_struct: &Vec<usize>, distrib: &str) -> (Vec<Vec<f32>>, V

// create the weights and the bias between the layers:
for i in 0..network_struct.len() - 1 {
// the things between x things is equal to x - 1
// the number of things between x things is equal to x - 1
next_layer = i + 1;

let weight_matrix: Vec<f32> = functions[function_to_call_i].0(network_struct[i], network_struct[next_layer]);
let weight_matrix: Vec<f32> = dist_list[function_to_call_i].0(network_struct[i], network_struct[next_layer]);
weights_tensor.push(weight_matrix);

let bias_vector: Vec<f32> = vec![0.0; network_struct[next_layer]];
Expand Down
25 changes: 25 additions & 0 deletions src/runst/weight_init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,29 @@ pub fn he_uniform_dis(column: usize, row: usize) -> Vec<f32> {
let matrix: Vec<f32> = random(column, row, a, b);

return matrix;
}

pub fn which_dis(dis: &str) -> (usize, Vec<(Box<dyn Fn(usize, usize)->Vec<f32>>, &str)>) {
///// list of the available distribution /////
type FunType = Box<dyn Fn(usize, usize)->Vec<f32>>;
// must be of the type of the output and input of the function to call

// linking the functions(FunType) to their name(&str):
let mut dist_list: Vec<(FunType, &str)> = Vec::new();
dist_list.push((Box::new(uniform_dis), "uniform_dis"));
dist_list.push((Box::new(normal_dis), "normal_dis"));
//dist_list.push((Box::new(xav_gro_normal_dis), "xav_gro_normal_dis"));
//dist_list.push((Box::new(xav_gro_uniform_dis), "xav_gro_uniform_dis"));
dist_list.push((Box::new(he_normal_dis), "he_normal_dis"));
dist_list.push((Box::new(he_uniform_dis), "he_uniform_dis"));

let mut function_to_call_i: usize = 0;

for i in 0..dist_list.len() {
if dist_list[i].1 == dis {
function_to_call_i = i;
}
}

return (function_to_call_i, dist_list);
}

0 comments on commit 15a51c1

Please sign in to comment.