-
-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathmain.rs
64 lines (58 loc) · 2.06 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
extern crate csv;
extern crate itertools;
extern crate lightgbm;
extern crate serde_json;
use itertools::zip;
use lightgbm::{Booster, Dataset};
use serde_json::json;
/// Loads a headerless, tab-separated data file into feature rows and labels.
///
/// The first column of every record is parsed as the `f32` label; every
/// remaining column is parsed as an `f64` feature value.
///
/// Returns `(features, labels)`, each holding one entry per record, in file
/// order.
///
/// # Panics
///
/// Panics if the file cannot be opened, if a record cannot be read, if a
/// record has no first column, or if any column fails to parse as a number.
fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
    let mut reader = csv::ReaderBuilder::new()
        .has_headers(false)
        .delimiter(b'\t')
        .from_path(file_path)
        .unwrap_or_else(|err| panic!("failed to open {}: {}", file_path, err));

    let mut labels: Vec<f32> = Vec::new();
    let mut features: Vec<Vec<f64>> = Vec::new();

    for result in reader.records() {
        let record = result
            .unwrap_or_else(|err| panic!("failed to read a record from {}: {}", file_path, err));

        // Pull the label off the front of the column iterator so the remaining
        // columns can be collected directly, without parsing the label a second
        // time or building a throwaway vector just to slice it.
        let mut columns = record.iter();
        let raw_label = columns
            .next()
            .unwrap_or_else(|| panic!("record without a label column in {}", file_path));
        let label = raw_label
            .parse::<f32>()
            .unwrap_or_else(|err| panic!("invalid label {:?} in {}: {}", raw_label, file_path, err));

        let feature: Vec<f64> = columns
            .map(|value| {
                value.parse::<f64>().unwrap_or_else(|err| {
                    panic!("invalid feature {:?} in {}: {}", value, file_path, err)
                })
            })
            .collect();

        labels.push(label);
        features.push(feature);
    }

    (features, labels)
}
/// Trains a LightGBM binary classifier on the bundled example data and prints
/// an evaluation summary.
///
/// Steps:
/// 1. Load the training and test files with `load_file`.
/// 2. Train a booster with 100 iterations, the `binary` objective and the
///    `auc` metric.
/// 3. Predict on the test features, printing `label, prediction` for each row.
/// 4. Print each feature name with its importance.
/// 5. Print how many predictions landed on the correct side of the 0.5
///    threshold, out of the number of predictions.
///
/// # Panics
///
/// Panics if loading data, building the dataset, training, predicting, or
/// querying feature names/importances fails.
fn main() -> std::io::Result<()> {
    let (train_features, train_labels) =
        load_file("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.train");
    let (test_features, test_labels) =
        load_file("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.test");

    let train_dataset = Dataset::from_mat(train_features, train_labels)
        .expect("failed to build the training dataset");

    let params = json! {
        {
            "num_iterations": 100,
            "objective": "binary",
            "metric": "auc"
        }
    };

    let booster = Booster::train(train_dataset, &params).expect("failed to train the booster");
    let result = booster
        .predict(test_features)
        .expect("failed to predict on the test features");

    // NOTE(review): the first entry of the prediction output is paired
    // element-wise with the test labels below, so it presumably holds one
    // score per test row - confirm against the lightgbm crate documentation.
    let predictions = &result[0];

    // Counts every correct prediction (both classes), using 0.5 as the
    // decision threshold.
    let mut correct = 0;
    for (label, pred) in zip(&test_labels, predictions) {
        let predicted_positive = *pred > 0.5_f64;
        if (*label == 1_f32 && predicted_positive) || (*label == 0_f32 && !predicted_positive) {
            correct += 1;
        }
        println!("{}, {}", label, pred);
    }

    println!("feature importance");
    let feature_name = booster
        .feature_name()
        .expect("failed to read feature names");
    let feature_importance = booster
        .feature_importance()
        .expect("failed to read feature importances");
    for (feature, importance) in zip(&feature_name, &feature_importance) {
        println!("{}: {}", feature, importance);
    }

    println!("result: {} / {}", correct, predictions.len());
    Ok(())
}