// estimator.h
#ifndef STABLELDA_H_
#define STABLELDA_H_
#include <vector>
#include <iostream>
#include <string>
#include <map>
#include "nodes.h"
using namespace std;
// Estimator bundles, as public data members, the scalar settings, corpus
// containers, link-constraint lists, tree structures and result matrices of
// one model, and declares the entry points to load data, run estimation,
// print top words and save output. This header holds declarations only; the
// member function definitions live elsewhere.
// NOTE(review): the include guard is STABLELDA_H_, so this is presumably the
// sampler/estimator of a "stable LDA" topic model -- confirm against the
// implementation file.
class Estimator {
public:
// --- Scalar settings -------------------------------------------------------
// alpha, beta, eta, num_topics, num_words and rand_seed share their names with
// the six constructor parameters below; presumably the constructor stores them
// here -- confirm in the definition.
// NOTE(review): alpha/beta/eta look like prior (smoothing) hyperparameters;
// their exact roles are not visible from this header.
double alpha;
double beta;
double eta;
int num_topics;
int num_words;
int rand_seed;
// Not a constructor parameter; presumably set while loading data -- TODO confirm.
int num_docs;
// --- Corpus data -----------------------------------------------------------
// NOTE(review): docs/samples are nested int vectors, presumably one inner
// vector per document (word ids and per-token assignments respectively) --
// verify against load_data.
vector<vector<int>> docs;
vector<vector<int>> samples;
// Presumably the length of each entry of docs -- TODO confirm.
vector<int> doc_lens;
// --- Clusters and link constraints ------------------------------------------
// NOTE(review): contents presumably come from the cluster file passed to
// load_data; the inner-vector layout of each list is not visible here.
vector<vector<int> > topical_clusters;
vector<vector<int>> mustlinks;
vector<vector<int>> cannotlinks;
// --- Tree structures (ROOT is not declared in this header; presumably it
// comes from "nodes.h") ------------------------------------------------------
ROOT root;
// Presumably maps an index (word id?) to a leaf position in the tree --
// TODO confirm against build_tree.
vector<int> leafmap;
// --- Vocabulary ------------------------------------------------------------
// vocab holds strings and vocab2id maps string -> int; presumably inverse
// lookups of each other -- confirm in readin_vocab.
vector<string> vocab;
map<string, int> vocab2id;
// One ROOT per entry; presumably one tree per topic -- TODO confirm.
vector<ROOT> topics;
// --- Counts and result matrices ----------------------------------------------
// NOTE(review): nd looks like a per-document count table, and theta/phi like
// the outputs of calc_theta()/calc_phi(); dimensions are not visible here.
vector<vector<int>> nd;
vector<vector<double>> theta;
vector<vector<double>> phi;
// Constructs an estimator from the scalar settings. All six arguments are
// taken by value.
Estimator(double alpha, double beta, double eta, int num_topics, int num_words, int rand_seed);
// Loads the model inputs from four files given by path/name (strings are
// passed by value). Private readers exist below for the data, vocab and
// cluster files; how z_file is consumed is not visible in this header.
void load_data(string data_file, string z_file, string cluster_file, string vocab_file);
// Runs estimation; `epochs` is presumably the number of passes/iterations --
// confirm in the definition.
void estimate(int epochs);
// Virtual destructor, so deleting a derived object through an Estimator*
// invokes the derived destructor.
virtual ~Estimator();
// Prints top words; N defaults to 10 (presumably the number of words shown
// per topic -- TODO confirm).
void print_topwords(int N=10);
// Saves results; output_path is passed by value. The files written and their
// format are defined in the implementation, not here.
void save(string output_path);
private:
vector<vector<int>> ml_cliques; //must-link connected components
vector<vector<int>> cl_cliques; //cannot-link connected components
// File readers taking the same parameter names as load_data's arguments;
// presumably called from load_data -- confirm in the definition.
void readin_data(string data_file);
void readin_vocab(string vocab_file);
void readin_clusters(string cluster_file);
// Presumably builds `root` / `topics` / `leafmap` from the constraint
// cliques -- TODO confirm.
void build_tree();
// Presumably (re)compute the `theta` and `phi` members respectively --
// TODO confirm.
void calc_theta();
void calc_phi();
};
#endif /* STABLELDA_H_ */