-
Notifications
You must be signed in to change notification settings - Fork 4
/
Runner.py
64 lines (51 loc) · 2.27 KB
/
Runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
from DecTM import DecTM
class Runner(object):
    """Train and run inference with a ``DecTM`` model in mini-batches.

    Reads ``config.batch_size``, ``config.num_epoch`` and ``config.num_topic``;
    the whole ``config`` object is also forwarded to ``DecTM``. The model is
    driven through ``self.model.sess.run`` (presumably a TF1-style session --
    TODO confirm against DecTM).
    """

    def __init__(self, config, init_embeddings=None):
        """Build the underlying ``DecTM`` model.

        Args:
            config: configuration object; passed to ``DecTM`` and read for
                batching/epoch settings.
            init_embeddings: accepted for interface compatibility; not used
                by this class.
        """
        self.config = config
        self.model = DecTM(self.config)

    @staticmethod
    def _batch_slices(data_size, batch_size):
        """Yield slices that cover positions ``0..data_size-1`` in batches.

        Full batches of ``batch_size`` come first. If ``data_size`` is not a
        multiple of ``batch_size``, one extra slice over the *last*
        ``batch_size`` positions (or all positions when
        ``data_size < batch_size``) is yielded. It overlaps the previous
        batch so that every batch keeps the same size (presumably because the
        model graph expects fixed-size input -- TODO confirm). No tail slice
        is produced when the data divides evenly, so no batch is repeated.
        """
        num_full = data_size // batch_size
        for i in range(num_full):
            yield slice(i * batch_size, (i + 1) * batch_size)
        if data_size % batch_size:
            yield slice(max(data_size - batch_size, 0), data_size)

    def train(self, X):
        """Fit the model on ``X`` for ``config.num_epoch`` epochs.

        Each epoch visits the rows of ``X`` in a fresh random order and runs
        ``self.model.optimizer`` once per mini-batch. Every 5 epochs it prints
        the mean of the loss values recorded for that epoch.

        Args:
            X: row-indexable data matrix; only ``X.shape[0]`` and row
                fancy-indexing are used.

        Returns:
            The value of ``self.model.beta`` after training, transposed.
        """
        data_size = X.shape[0]
        batch_size = self.config.batch_size
        fetches = (self.model.optimizer, self.model.loss)
        for epoch in range(1, self.config.num_epoch + 1):
            order = np.random.permutation(data_size)
            loss = np.zeros((data_size,))
            for batch in self._batch_slices(data_size, batch_size):
                feed_dict = {self.model.x: X[order[batch]]}
                _, batch_loss = self.model.sess.run(fetches, feed_dict=feed_dict)
                # Positions covered by an overlapping tail batch are overwritten
                # with the tail's loss.
                loss[batch] = batch_loss
            if epoch % 5 == 0:
                print("Epoch: {:03d} loss={:.3f}".format(epoch, np.mean(loss)))
        beta = self.model.sess.run(self.model.beta).T
        return beta

    def test(self, X):
        """Evaluate ``self.model.theta`` for every row of ``X``.

        Rows are processed in their original order in batches of
        ``config.batch_size``. No optimizer step is run.

        Args:
            X: row-indexable data matrix.

        Returns:
            ``np.ndarray`` of shape ``(X.shape[0], config.num_topic)``.
        """
        data_size = X.shape[0]
        batch_size = self.config.batch_size
        theta = np.zeros((data_size, self.config.num_topic))
        for batch in self._batch_slices(data_size, batch_size):
            # Only theta is needed here; the model loss is not evaluated.
            theta[batch] = self.model.sess.run(
                self.model.theta, feed_dict={self.model.x: X[batch]})
        return theta