-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdef_y_layer.hpp
39 lines (27 loc) · 1.11 KB
/
def_y_layer.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#pragma once
/// Abstract summary of prediction results; instances are returned by
/// DEFYLayer::prediction_stats().
///
/// Implementations are handled polymorphically (through shared_ptr to this
/// base), so the destructor is virtual.
class PredictionStats {
public:
    // Polymorphic base: a virtual destructor makes destroying a derived
    // object through a PredictionStats pointer (raw or unique_ptr) well-defined
    // instead of undefined behavior.
    virtual ~PredictionStats() = default;

    /// Print the statistics in human-readable form.
    /// NOTE(review): output destination (stdout vs. a log) is up to the
    /// implementation — confirm.
    virtual void pretty_print() const = 0;

    /// The statistic values rendered as strings.
    /// NOTE(review): presumably ordered to match
    /// DEFYLayer::prediction_header() — confirm against implementations.
    virtual std::vector<std::string> vals() const = 0;

    /// The statistic values as numbers.
    /// NOTE(review): presumably parallel to vals() — confirm.
    virtual std::vector<double> numeric_vals() const = 0;
};
class DEFYLayer {
public:
struct LogPRowCol {
shared_ptr<arma::rowvec> log_p_row_train;
shared_ptr<arma::colvec> log_p_col_train;
shared_ptr<arma::rowvec> log_p_row_test;
shared_ptr<arma::colvec> log_p_col_test;
};
// row/column sum of log p(y)
virtual LogPRowCol log_p_row_column(shared_ptr<arma::mat> z1,
shared_ptr<arma::mat> z2,
const ExampleIds& example_ids) = 0;
virtual LogPRowCol log_p_row_column(shared_ptr<arma::mat> z1,
shared_ptr<arma::mat> z2) = 0;
virtual LogPRowCol log_likelihood_row_column(shared_ptr<arma::mat> z1,
shared_ptr<arma::mat> z2) = 0;
virtual shared_ptr<PredictionStats> prediction_stats(shared_ptr<arma::mat> z1,
shared_ptr<arma::mat> z2) = 0;
virtual vector<string> prediction_header() const = 0;
};