-
Notifications
You must be signed in to change notification settings - Fork 97
/
test_atten.py
143 lines (119 loc) · 4.36 KB
/
test_atten.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
测试并展示 attention
"""
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import tensorflow as tf
from sequence_to_sequence import SequenceToSequence
from data_utils import batch_flow
def _train(model_kwargs, x_train, y_train, ws_input, ws_target,
           batch_size, config, save_path, n_epoch=2):
    """Train a SequenceToSequence model and checkpoint it to save_path.

    Args:
        model_kwargs: dict with bidirectional/cell_type/depth/attention_type,
            forwarded to SequenceToSequence.
        x_train, y_train: training sequences.
        ws_input, ws_target: input/target vocabularies (len() gives size).
        batch_size: minibatch size.
        config: tf.ConfigProto used for the session.
        save_path: checkpoint path written via model.save.
        n_epoch: number of training epochs.
    """
    from tqdm import tqdm

    # One extra step so a partial last batch is still consumed.
    steps = int(len(x_train) / batch_size) + 1
    with tf.Graph().as_default():
        model = SequenceToSequence(
            input_vocab_size=len(ws_input),
            target_vocab_size=len(ws_target),
            batch_size=batch_size,
            learning_rate=0.001,
            parallel_iterations=1,
            **model_kwargs
        )
        init = tf.global_variables_initializer()
        with tf.Session(config=config) as sess:
            sess.run(init)
            for epoch in range(1, n_epoch + 1):
                costs = []
                flow = batch_flow(
                    [x_train, y_train], [ws_input, ws_target], batch_size
                )
                bar = tqdm(range(steps),
                           desc='epoch {}, loss=0.000000'.format(epoch))
                for _ in bar:
                    x, xl, y, yl = next(flow)
                    cost = model.train(sess, x, xl, y, yl)
                    costs.append(cost)
                    # Show the running mean loss on the progress bar.
                    bar.set_description('epoch {} loss={:.6f}'.format(
                        epoch,
                        np.mean(costs)
                    ))
            model.save(sess, save_path)


def _show_attention(model_kwargs, x_test, y_test, ws_input, ws_target,
                    config, save_path, n_show=10):
    """Reload the checkpoint in decode mode and plot attention heatmaps.

    Attention weights cannot be extracted under beam search, so the model
    is rebuilt with beam_width=0 (greedy decoding) and batch_size=1.

    Args:
        model_kwargs: same architecture kwargs used for training.
        x_test, y_test: held-out sequences to visualize.
        ws_input, ws_target: vocabularies used to invert token ids.
        config: tf.ConfigProto used for the session.
        save_path: checkpoint path read via model.load.
        n_show: number of test samples to display before stopping.
    """
    with tf.Graph().as_default():
        model_pred = SequenceToSequence(
            input_vocab_size=len(ws_input),
            target_vocab_size=len(ws_target),
            batch_size=1,
            mode='decode',
            beam_width=0,
            parallel_iterations=1,
            **model_kwargs
        )
        init = tf.global_variables_initializer()
        with tf.Session(config=config) as sess:
            sess.run(init)
            model_pred.load(sess, save_path)
            flow = batch_flow([x_test, y_test], [ws_input, ws_target], 1)
            for shown, (x, xl, y, yl) in enumerate(flow, start=1):
                pred, atten = model_pred.predict(
                    sess,
                    np.array(x),
                    np.array(xl),
                    attention=True
                )
                # Map token ids back to symbols for display.
                ox = ws_input.inverse_transform(x[0])
                oy = ws_target.inverse_transform(y[0])
                op = ws_target.inverse_transform(pred[0])
                print(ox)
                print(oy)
                print(op)
                fig, ax = plt.subplots()
                # Drop the singleton batch axis: (out_len, 1, in_len)
                # -> (out_len, in_len) — TODO confirm axis meaning.
                cax = ax.matshow(atten.reshape(
                    [atten.shape[0], atten.shape[2]]
                ), cmap=cm.coolwarm)
                ax.set_xticks(np.arange(len(ox)))
                ax.set_yticks(np.arange(len(op)))
                ax.set_xticklabels(ox)
                ax.set_yticklabels(op)
                fig.colorbar(cax)
                plt.show()
                print('-' * 30)
                if shown >= n_show:
                    break


def test(bidirectional, cell_type, depth, attention_type):
    """Train a small seq2seq model on fake data and plot its attention.

    Args:
        bidirectional: whether the encoder is bidirectional.
        cell_type: RNN cell type, e.g. 'gru' or 'lstm'.
        depth: number of stacked RNN layers.
        attention_type: attention mechanism, 'Luong' or 'Bahdanau'.
    """
    from fake_data import generate

    # Generate synthetic training/eval data.
    x_data, y_data, ws_input, ws_target = generate(size=10000)

    # 90% / 10% train / test split.
    split = int(len(x_data) * 0.9)
    x_train, x_test, y_train, y_test = (
        x_data[:split], x_data[split:], y_data[:split], y_data[split:])

    # Force CPU so the demo runs on any machine.
    config = tf.ConfigProto(
        device_count={'CPU': 1, 'GPU': 0},
        allow_soft_placement=True,
        log_device_placement=False
    )
    save_path = '/tmp/s2ss_atten.ckpt'

    # Architecture options shared by the training and decoding graphs.
    model_kwargs = dict(
        bidirectional=bidirectional,
        cell_type=cell_type,
        depth=depth,
        attention_type=attention_type,
    )

    _train(model_kwargs, x_train, y_train, ws_input, ws_target,
           batch_size=32, config=config, save_path=save_path, n_epoch=2)
    _show_attention(model_kwargs, x_test, y_test, ws_input, ws_target,
                    config=config, save_path=save_path)
if __name__ == '__main__':
    # Single representative configuration. To sweep the full grid, loop
    # over bidirectional in (True, False), cell_type in ('gru', 'lstm'),
    # depth in (1, 2, 3) and attention_type in ('Luong', 'Bahdanau').
    test(bidirectional=True, cell_type='lstm',
         depth=2, attention_type='Bahdanau')