-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathbbvi.hpp
39 lines (32 loc) · 1020 Bytes
/
bbvi.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#pragma once
#include "utils.hpp"
/// Gradient statistics captured at two stages of the estimator:
/// before control variates are applied (`*_g0`) and after (`*_g1`).
///
/// Field names suggest "mean of squared gradient entries" and "gradient
/// variance". NOTE(review): that reading comes from the names only;
/// confirm against the code that fills these in.
///
/// Every field starts at zero. `+=` and `/=` act independently on each of
/// the four fields. This makes it easy to sum several snapshots and then
/// divide by their count (assumed usage; callers are not visible here).
struct BBVIStats {
    // --- before control variates ---
    double mean_sqr_g0 = 0.0;
    double var_g0 = 0.0;

    // --- after control variates ---
    double mean_sqr_g1 = 0.0;
    double var_g1 = 0.0;

    /// All four statistics are zero-initialized by their member initializers.
    BBVIStats() = default;

    /// Adds each field of `rhs` into the matching field of this object.
    /// @return *this, to allow chaining.
    BBVIStats& operator+=(const BBVIStats& rhs) {
        mean_sqr_g0 += rhs.mean_sqr_g0;
        var_g0      += rhs.var_g0;
        mean_sqr_g1 += rhs.mean_sqr_g1;
        var_g1      += rhs.var_g1;
        return *this;
    }

    /// Divides every field by `divisor`.
    /// No zero check is done: dividing by 0.0 yields inf/NaN per IEEE-754.
    /// @return *this, to allow chaining.
    BBVIStats& operator/=(double divisor) {
        mean_sqr_g0 /= divisor;
        var_g0      /= divisor;
        mean_sqr_g1 /= divisor;
        var_g1      /= divisor;
        return *this;
    }
};
/// Declaration only; the definition is not in this header.
///
/// Going by its name and signature, this presumably computes the element-wise
/// mean and variance across the matrices in `list` and writes them to `mean`
/// and `var`. TODO: confirm against the definition (including whether the
/// variance is the population or the sample estimator).
///
/// @param list  collection of matrices (`VecOfMat`, presumably declared in
///              utils.hpp). NOTE(review): it is passed by non-const reference,
///              so the implementation is permitted to modify it. Verify before
///              passing data that must stay intact.
/// @param mean  output matrix (non-const reference).
/// @param var   output matrix (non-const reference).
void compute_mean_var(VecOfMat& list, arma::mat& mean, arma::mat& var);
/// Declaration only; the definition is not in this header.
///
/// By its name, this presumably computes a black-box variational inference
/// (BBVI) gradient estimate for a factorized variational family. TODO:
/// confirm against the definition.
///
/// @param options     read-only configuration tree (`pt::ptree`, presumably
///                    boost::property_tree). The keys it reads are not visible here.
/// @param grad_log_q  matrices of grad log q. Presumably one per Monte Carlo
///                    sample; confirm.
/// @param log_p       matrices of log p values, passed read-only. The layout
///                    is presumed to parallel `grad_log_q`; confirm.
/// @param log_q       matrices of log q values, passed read-only. The layout
///                    is presumed to parallel `grad_log_q`; confirm.
/// @param stats       non-const reference. Presumably filled with gradient
///                    statistics before and after control variates (see the
///                    field comments of BBVIStats).
/// @param threads     presumably the number of worker threads to use; confirm.
/// @return a `shared_ptr` owning an `arma::mat`, presumably the gradient
///         estimate.
shared_ptr<arma::mat> grad_bbvi_factorized(const pt::ptree& options,
const VecOfMat& grad_log_q,
const VecOfMat& log_p,
const VecOfMat& log_q,
BBVIStats& stats,
int threads);