-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdef_gaussian_layer.hpp
182 lines (155 loc) · 5.44 KB
/
def_gaussian_layer.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#pragma once
#include <cassert>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_sf.h>
#include "utils.hpp"
#include "def_layer.hpp"
#include "link_function.hpp"
#include "serialization.hpp"
// Univariate Gaussian prior N(mu, sigma^2) applied to a single latent value.
//
// Configuration keys read by the constructor:
//   "layer.mu"     optional, defaults to 0.0 (a zero-centered prior)
//   "layer.sigma"  required; the property-tree lookup fails if it is absent
//
// When a non-zero mean is configured, debug builds additionally require
// "layer.w_mu_init_offset" to be present and equal to it.
class GaussianPriorLayer : public DEFPriorLayer {
protected:
  pt::ptree options;    // copy of the configuration handed to the constructor
  double mu;            // prior mean
  double sigma;         // prior standard deviation
  double log_sqrt_2pi;  // cached log(sqrt(2*pi))
  double sigma2;        // cached sigma^2

public:
  // Reads the prior parameters from `options`. `initializer` is not used here.
  GaussianPriorLayer(const pt::ptree& options, const DEFInitializer& initializer)
    : options( options ) {
    // NOTE: inside this body `options` names the constructor parameter, which
    // shadows the member of the same name; both hold the same tree.
    mu = options.get<double>("layer.mu", 0.0);
    if (mu != 0.0) {
      LOG(debug) << "gaussian prior layer mu=" << mu;
      // Debug-only consistency check between the prior mean and the offset
      // configured under "layer.w_mu_init_offset".
      assert(mu == options.get<double>("layer.w_mu_init_offset"));
    }
    sigma = options.get<double>("layer.sigma");
    // compute_log_p() takes log(sigma) and divides by sigma^2, so anything
    // that is not strictly positive would silently produce NaN/inf densities.
    assert(sigma > 0.0);
    log_sqrt_2pi = log(sqrt(2*M_PI));
    sigma2 = sigma * sigma;
  }

  // Returns the log-density of N(mu, sigma^2) evaluated at z:
  //   -log(sqrt(2*pi)) - log(sigma) - (z - mu)^2 / (2*sigma^2)
  virtual double compute_log_p(double z) {
    const double centered = z - mu;
    return - log_sqrt_2pi - log(sigma) - centered*centered / (2*sigma2);
  }
};
// Fully factorized Gaussian layer: every (i, j) entry has its own independent
// N(mu, sigma^2).
//
//   w_mu(i,j)     is used directly as the mean (identity link).
//   w_sigma(i,j)  is an unconstrained value; the standard deviation is
//                 lf->f(w_sigma(i,j)).
//
// Both matrices are sized layer_size x n_examples in init(). `options`,
// `n_examples`, `all_examples` and `register_param` are not declared in this
// class -- presumably inherited from InferenceFactorizedLayer; confirm there.
class GaussianFactorizedLayer : public InferenceFactorizedLayer {
protected:
  // Default initializers keep a default-constructed layer (used before
  // deserialization) from carrying indeterminate values until init() runs.
  arma::uword layer_size = 0;
  Serializable<arma::mat> w_mu;
  Serializable<arma::mat> w_sigma;
  LinkFunction* lf = nullptr;

public:
  // Log-density of N(w_mu(i,j), f(w_sigma(i,j))^2) evaluated at z.
  // A non-finite result is logged at fatal level and then asserted on.
  virtual double compute_log_q(double z, arma::uword i, arma::uword j) {
    // mu always uses the identity link function
    auto mu = w_mu(i,j);
    auto sigma = lf->f(w_sigma(i,j));
    const auto diff = z - mu;
    auto log_q = - log(2*M_PI)*0.5 - log(sigma) - (diff*diff) / (2*sigma*sigma);
    LOG_IF(fatal, !isfinite(log_q))
      << "mu=" << mu << " sigma=" << sigma
      << " z=" << z << " log_q=" << log_q;
    assert(isfinite(log_q));
    return log_q;
  }

  // Draws one sample from N(w_mu(i,j), f(w_sigma(i,j))^2) using `rng`.
  virtual double sample(gsl_rng* rng, arma::uword i, arma::uword j) {
    // mu always uses the identity link function
    auto mu = w_mu(i,j);
    auto sigma = lf->f(w_sigma(i,j));
    // gsl_ran_gaussian is zero-mean, so shift the draw by mu.
    return gsl_ran_gaussian(rng, sigma) + mu;
  }

  // Mean of entry (i, j), i.e. w_mu(i,j).
  virtual double mean(arma::uword i, arma::uword j) {
    return w_mu(i,j);
  }

  // Copies w_mu and w_sigma from `other`, which must be a
  // GaussianFactorizedLayer. Throws runtime_error otherwise. The link function
  // and layer_size are not copied.
  virtual void copy_params(InferenceFactorizedLayer* other) {
    auto* other_gfl = dynamic_cast<GaussianFactorizedLayer*>(other);
    if (other_gfl == nullptr)
      throw runtime_error("can't cast to GaussianFactorizedLayer");
    w_mu = other_gfl->w_mu;
    w_sigma = other_gfl->w_sigma;
  }

  // Constrains the sigma parameters of the columns listed in `example_ids`.
  //
  //   "global.fixed_gaussian_sigma" > 0: every entry is overwritten so that
  //       the resulting standard deviation equals that value.
  //   otherwise, "global.min_gaussian_sigma" > 0: entries are raised so the
  //       standard deviation is at least that value.
  //   otherwise: nothing is changed.
  //
  // Only strictly positive settings count as enabled, matching the checks in
  // the constructor; a non-positive value is never passed to f_inv.
  virtual void truncate(const ExampleIds& example_ids) {
    const double fixed_sigma = options.get("global.fixed_gaussian_sigma", 0.0);
    if (fixed_sigma > 0) {
      // fixed variance: store the pre-link value that maps to fixed_sigma
      auto fixed_sigma0 = lf->f_inv(fixed_sigma);
      for (auto j : example_ids) {
        w_sigma.col(j).fill(fixed_sigma0);
      }
      return;
    }

    // no fixed variance
    const double min_sigma = options.get("global.min_gaussian_sigma", 0.0);
    if (!(min_sigma > 0))
      return;  // no truncation requested

    // The lower bound is applied in pre-link space.
    // NOTE(review): this relies on lf->f being non-decreasing -- confirm for
    // every link function in use.
    auto min_sigma0 = lf->f_inv(min_sigma);
    for (auto j : example_ids) {
      w_sigma.col(j).transform([min_sigma0](double v) { return max(v, min_sigma0); });
    }
  }

  // Applies truncate() to all_examples.
  virtual void truncate() {
    truncate(all_examples);
  }

  // Leaves the layer empty; load() calls init(true) after reading the state.
  GaussianFactorizedLayer() {}

  // Builds the layer and randomly initializes its parameters with draws from
  // initializer.rng:
  //   w_mu    = N(0,1) * "layer.w_mu_init" + "layer.w_mu_init_offset"
  //   w_sigma = N(0,1) * "layer.w_sigma_init"   (value before the link function)
  // "layer.w_mu_init_offset" is optional and defaults to 0; when it is
  // non-zero, debug builds require "layer.mu" to be present and equal to it.
  GaussianFactorizedLayer(const pt::ptree& options,
                          const DEFInitializer& initializer)
    : InferenceFactorizedLayer(options) {
    init(false);
    gsl_rng* rng = initializer.rng;

    auto w_mu_init = options.get<double>("layer.w_mu_init");
    auto w_mu_init_offset = options.get<double>("layer.w_mu_init_offset", 0.0);
    if (w_mu_init_offset != 0) {
      LOG(debug) << "guassian factorized layer w_mu_init_offset=" << w_mu_init_offset;
      assert(w_mu_init_offset == options.get<double>("layer.mu"));
    }
    for (auto& v : w_mu) {
      // use a gaussian to initialize mu
      v = gsl_ran_gaussian(rng, 1) * w_mu_init + w_mu_init_offset;
    }

    auto w_sigma_init = options.get<double>("layer.w_sigma_init");
    for (auto& v : w_sigma) {
      // use a gaussian to initialize sigma before the link function
      v = gsl_ran_gaussian(rng, 1) * w_sigma_init;
    }

    // The remaining reads only report the sigma constraints; they are
    // enforced in truncate().
    auto min_sigma = options.get("global.min_gaussian_sigma", 0.0);
    if (min_sigma > 0)
      LOG(debug) << "global.min_gaussian_sigma=" << min_sigma;
    auto fixed_sigma = options.get("global.fixed_gaussian_sigma", 0.0);
    if (fixed_sigma > 0) {
      LOG(debug) << "global.fixed_gaussian_sigma=" << fixed_sigma;
    }
  }

  // Reads "layer.size" and "lf" from the options, sizes both parameter
  // matrices to layer_size x n_examples, and registers each matrix together
  // with its score function (the derivative of log q with respect to that
  // parameter). `deserialize` is forwarded unchanged to register_param.
  //
  // NOTE(review): the score closures capture `this`. If register_param keeps
  // them, a copied or moved layer would still refer to the original object --
  // confirm how layers are copied.
  void init(bool deserialize) {
    layer_size = options.get<int>("layer.size");
    lf = get_link_function(options.get<string>("lf"));
    w_mu.set_size(layer_size, n_examples);
    w_sigma.set_size(layer_size, n_examples);

    // d/d(mu) log q = (z - mu) / sigma^2
    ScoreFunction score_mu = [this](double z, arma::uword i, arma::uword j) {
      auto mu = w_mu(i,j);
      auto sigma = lf->f(w_sigma(i,j));
      return (z-mu) / (sigma*sigma);
    };
    register_param(&w_mu, score_mu, deserialize);

    // d/d(sigma0) log q = g(sigma0) * (-1/sigma + (z - mu)^2 / sigma^3),
    // with sigma = f(sigma0); assumes lf->g is the derivative of lf->f --
    // TODO confirm against link_function.hpp.
    ScoreFunction score_sigma = [this](double z, arma::uword i, arma::uword j) {
      auto mu = w_mu(i,j);
      auto sigma0 = w_sigma(i,j);
      auto sigma = lf->f(sigma0);
      return lf->g(sigma0) * (-1.0/sigma + (z-mu)*(z-mu) / (sigma*sigma*sigma));
    };
    register_param(&w_sigma, score_sigma, deserialize);
  }

  friend class boost::serialization::access;
  BOOST_SERIALIZATION_SPLIT_MEMBER();

  // Writes w_mu, w_sigma and then the base-class state. load() must read the
  // fields in exactly this order.
  template<class Archive>
  void save(Archive& ar, const unsigned int) const {
    ar & w_mu;
    ar & w_sigma;
    ar & boost::serialization::base_object<const InferenceFactorizedLayer>(*this);
  }

  // Reads the fields in the order save() wrote them, then calls init(true) to
  // set layer_size and lf and to register the parameters.
  template<class Archive>
  void load(Archive& ar, const unsigned int) {
    ar & w_mu;
    ar & w_sigma;
    ar & boost::serialization::base_object<InferenceFactorizedLayer>(*this);
    init(true);
  }
};