forked from pengzhao-intel/keras_nmt
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsearch.py
334 lines (271 loc) · 13.9 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
# models: BeamSearch
import numpy
import copy
class BeamSearch(object):
    """Beam-search / stochastic-sampling decoder for an encoder-decoder NMT model.

    With ``stochastic=True`` one translation is sampled token by token; otherwise a
    beam of ``beam_size`` partial hypotheses is expanded until every hypothesis emits
    EOS (word id 0) or ``maxlen`` steps are reached.  Optional attention, coverage and
    reconstruction scoring are switched on via the ``configuration`` dict.
    """

    def __init__(self, enc_dec, configuration, beam_size=1, maxlen=50, stochastic=True):
        """Store the decoding options.

        :param enc_dec: compiled encoder-decoder exposing the ``compile_*`` functions
            used below (init/context, next-state/probs, fertility, inverse variants)
        :param configuration: dict with keys ``with_attention``, ``with_coverage``,
            ``coverage_dim``, ``coverage_type``, ``max_fertility``,
            ``with_reconstruction``, ``reconstruction_weight``
        :param beam_size: number of live hypotheses; must be 1 when sampling
        :param maxlen: maximum length of the output sentence
        :param stochastic: True for sampling, False for beam search
        """
        self.enc_dec = enc_dec
        # if sampling, beam_size = 1
        self.beam_size = beam_size
        # max length of output sentence
        self.maxlen = maxlen
        # stochastic == True stands for sampling
        self.stochastic = stochastic
        self.with_attention = configuration['with_attention']
        self.with_coverage = configuration['with_coverage']
        self.coverage_dim = configuration['coverage_dim']
        self.coverage_type = configuration['coverage_type']
        self.max_fertility = configuration['max_fertility']
        self.with_reconstruction = configuration['with_reconstruction']
        self.reconstruction_weight = configuration['reconstruction_weight']
        if self.beam_size > 1:
            assert not self.stochastic, 'Beam search does not support stochastic sampling'

    def apply(self, input):
        """Decode ``input`` (shape: time_steps x 1 — TODO confirm against caller).

        Returns ``[sample, sample_score]`` extended, depending on the configuration,
        with alignments, coverage, fertility and reconstruction score/alignments.
        In stochastic mode ``sample`` is a flat word list and ``sample_score`` a
        scalar; in beam mode both are per-hypothesis lists.
        """
        sample = []
        sample_score = []
        if self.stochastic:
            sample_score = 0
        # decoder hidden states per sample, kept for the reconstruction pass
        sample_states = []
        if self.with_attention:
            sample_alignment = []
            if self.with_coverage:
                sample_coverage = []
        # get initial state of decoder rnn and encoder context
        ret = self.enc_dec.compile_init_and_context([input])
        next_state, c = ret[0], ret[1]
        if not self.with_attention:
            init_ctx0 = ret[2]
        live_k = 1
        dead_k = 0
        hyp_samples = [[]] * live_k
        hyp_scores = numpy.zeros(live_k).astype('float32')
        hyp_states = [[]] * live_k
        if self.with_attention:
            hyp_alignments = [[]]
            # note that batch size is the second dimension of coverage and will be
            # used in the later decoding, thus we need a structure different from
            # the above ones
            if self.with_coverage:
                hyp_coverages = numpy.zeros((c.shape[0], 1, self.coverage_dim), dtype='float32')
                if self.coverage_type == 'linguistic':
                    # note the return result is a list even when it contains only one element
                    fertility = self.enc_dec.compile_fertility([c])[0]
        # bos indicator
        next_w = -1 * numpy.ones((1,)).astype('int32')
        for i in range(self.maxlen):
            inps = [next_w, next_state]
            if self.with_attention:
                ctx = numpy.tile(c, [live_k, 1])
                inps.append(ctx)
                if self.with_coverage:
                    # NOTE(review): unlike Align.apply, fertility is not appended to
                    # inps here even for 'linguistic' coverage — verify this matches
                    # the signature of compile_next_state_and_probs.
                    inps.append(hyp_coverages)
            else:
                init_ctx = numpy.tile(init_ctx0, [live_k, 1])
                inps.append(init_ctx)
            ret = self.enc_dec.compile_next_state_and_probs(inps)
            next_p, next_state, next_w = ret[0], ret[1], ret[2]
            if self.with_attention:
                alignment = ret[3]
                # update the coverage after attention operation
                if self.with_coverage:
                    coverages = ret[4]
            if self.stochastic:
                nw = next_w[0]
                sample.append(nw)
                sample_states.append(next_state[0])
                sample_score += next_p[0, nw]
                if self.with_attention and self.with_coverage:
                    hyp_coverages = coverages
                # 0 for EOS
                if nw == 0:
                    break
            else:
                # expand every live hypothesis by the whole vocabulary and keep the
                # (beam_size - dead_k) lowest-cost candidates
                cand_scores = hyp_scores[:, None] - numpy.log(next_p)
                cand_flat = cand_scores.flatten()
                ranks_flat = cand_flat.argsort()[:self.beam_size - dead_k]
                voc_size = next_p.shape[1]
                # floor division: '/' was Python-2 integer division and yields
                # un-indexable floats on Python 3
                trans_indices = ranks_flat // voc_size
                word_indices = ranks_flat % voc_size
                costs = cand_flat[ranks_flat]
                new_hyp_samples = []
                new_hyp_scores = numpy.zeros(self.beam_size - dead_k).astype('float32')
                new_hyp_states = []
                if self.with_attention:
                    new_hyp_alignments = []
                    if self.with_coverage:
                        new_hyp_coverages = numpy.zeros((c.shape[0], self.beam_size - dead_k, self.coverage_dim),
                                                        dtype='float32')
                for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
                    new_hyp_samples.append(hyp_samples[ti] + [wi])
                    new_hyp_scores[idx] = copy.copy(costs[idx])
                    new_hyp_states.append(hyp_states[ti] + [copy.copy(next_state[ti])])
                    if self.with_attention:
                        new_hyp_alignments.append(hyp_alignments[ti] + [alignment[:, ti]])
                        if self.with_coverage:
                            new_hyp_coverages[:, idx, :] = coverages[:, ti, :]
                # check the finished samples
                new_live_k = 0
                hyp_samples = []
                hyp_scores = []
                hyp_states = []
                if self.with_attention:
                    hyp_alignments = []
                    if self.with_coverage:
                        indices = []
                for idx in range(len(new_hyp_samples)):
                    if new_hyp_samples[idx][-1] == 0:
                        # hypothesis ended with EOS: move it to the finished set
                        sample.append(new_hyp_samples[idx])
                        sample_score.append(new_hyp_scores[idx])
                        if self.with_attention:
                            # for reconstruction
                            # NOTE(review): sample_states is only collected here under
                            # with_attention, while the leftover-dump below collects it
                            # unconditionally — confirm non-attention reconstruction.
                            sample_states.append(new_hyp_states[idx])
                            sample_alignment.append(new_hyp_alignments[idx])
                            if self.with_coverage:
                                # for neural coverage, we use the mean value of the vector
                                sample_coverage.append(new_hyp_coverages[:, idx, :].mean(1))
                        dead_k += 1
                    else:
                        hyp_samples.append(new_hyp_samples[idx])
                        hyp_scores.append(new_hyp_scores[idx])
                        hyp_states.append(new_hyp_states[idx])
                        if self.with_attention:
                            hyp_alignments.append(new_hyp_alignments[idx])
                            if self.with_coverage:
                                indices.append(idx)
                        new_live_k += 1
                hyp_scores = numpy.array(hyp_scores)
                live_k = new_live_k
                if self.with_attention:
                    if self.with_coverage:
                        # note now live_k has changed
                        hyp_coverages = numpy.zeros((c.shape[0], live_k, self.coverage_dim), dtype='float32')
                        for idx in range(live_k):
                            hyp_coverages[:, idx, :] = new_hyp_coverages[:, indices[idx], :]
                if live_k < 1 or dead_k >= self.beam_size:
                    break
                next_w = numpy.array([w[-1] for w in hyp_samples])
                next_state = numpy.array([s[-1] for s in hyp_states])
        if not self.stochastic:
            # dump every remaining one
            if live_k > 0:
                for idx in range(live_k):
                    sample.append(hyp_samples[idx])
                    sample_score.append(hyp_scores[idx])
                    sample_states.append(hyp_states[idx])
                    if self.with_attention:
                        sample_alignment.append(hyp_alignments[idx])
                        if self.with_coverage:
                            sample_coverage.append(hyp_coverages[:, idx, :].mean(1))
        else:
            if self.with_attention and self.with_coverage:
                sample_coverage = hyp_coverages[:, 0, :].mean(1)
        # for reconstruction
        if self.with_reconstruction:
            # build inverse_c and mask
            if self.stochastic:
                sample_states = [sample_states]
            # sample_num could be 1 or greater, for example, beam search
            sample_num = len(sample_states)
            inverse_sample_score = numpy.zeros(sample_num).astype('float32')
            if self.with_attention:
                inverse_sample_alignment = [[] for i in range(sample_num)]
            my = max([len(s) for s in sample_states])
            inverse_c = numpy.zeros((my, sample_num, sample_states[0][0].shape[0]), dtype='float32')
            # mask shape: time_steps, nb_samples
            mask = numpy.zeros((my, sample_num), dtype='float32')
            for idx in range(sample_num):
                inverse_c[:len(sample_states[idx]), idx, :] = sample_states[idx]
                mask[:len(sample_states[idx]), idx] = 1.
            # get initial state of decoder rnn and encoder context
            inverse_ret = self.enc_dec.compile_inverse_init_and_context([inverse_c])
            inverse_next_state = inverse_ret[0]
            if not self.with_attention:
                inverse_init_ctx0 = inverse_ret[1]
            to_reconstruct_input = input[:, 0]  # time_steps, 1D array
            for i in range(len(to_reconstruct_input)):
                # whether input contains eos?
                inverse_next_w = numpy.array([to_reconstruct_input[i - 1]] * sample_num) if i > 0 else -1 * numpy.ones(
                    (sample_num,)).astype('int32')
                inps = [inverse_next_w, mask, inverse_next_state]
                if self.with_attention:
                    inps.append(inverse_c)
                else:
                    inps.append(inverse_init_ctx0)
                ret = self.enc_dec.compile_inverse_next_state_and_probs(inps)
                inverse_next_p, inverse_next_state, inverse_next_w = ret[0], ret[1], ret[2]
                # compute reconstruction error
                inverse_sample_score -= numpy.log(inverse_next_p[:, to_reconstruct_input[i]])
                if self.with_attention:
                    inverse_alignment = ret[3]
                    # for each sample (guarded: inverse_alignment only exists with
                    # attention; previously this raised NameError without it)
                    for idx in range(sample_num):
                        inverse_sample_alignment[idx].append(inverse_alignment[:len(sample_states[idx]), idx])
            # combine sample_score and reconstructed_score
            sample_score += inverse_sample_score * self.reconstruction_weight
        results = [sample, sample_score]
        if self.with_attention:
            results.append(sample_alignment)
            if self.with_coverage:
                results.append(sample_coverage)
                if self.coverage_type == 'linguistic':
                    results.append(fertility[:, 0])
        if self.with_reconstruction:
            results.append(inverse_sample_score)
            if self.with_attention:
                # inverse alignments only exist for attention models
                results.append(inverse_sample_alignment)
        return results
# for forced alignment
class Align(object):
    """Compute forced alignments between a given source and target sentence.

    Runs the decoder over the known target words (teacher forcing) and records
    the attention weights at each step; optionally also the coverage vector,
    fertility, and the reconstruction-pass alignments.
    """

    def __init__(self, enc_dec, with_attention=True, with_coverage=False, coverage_dim=1, coverage_type='linguistic',
                 max_fertility=2, with_reconstruction=False):
        """Store alignment options.

        :param enc_dec: compiled encoder-decoder exposing the ``compile_*`` functions
        :param with_attention: must be True; alignment is read from attention weights
        :param with_coverage: whether the model maintains a coverage vector
        :param coverage_dim: dimensionality of the coverage vector
        :param coverage_type: 'linguistic' coverage additionally uses fertility
        :param max_fertility: fertility upper bound (stored, unused here directly)
        :param with_reconstruction: whether to also compute inverse alignments
        """
        self.enc_dec = enc_dec
        assert with_attention, "Align only supports attention model"
        self.with_coverage = with_coverage
        self.coverage_dim = coverage_dim
        self.coverage_type = coverage_type
        self.max_fertility = max_fertility
        self.with_reconstruction = with_reconstruction

    def apply(self, source, target):
        """Return ``[alignment]`` plus, per configuration, coverage, fertility
        and the reconstruction (target-to-source) alignment."""
        alignment = []
        # get initial state of decoder rnn and encoder context
        ret = self.enc_dec.compile_init_and_context([source])
        next_state, ctx = ret[0], ret[1]
        # decoder hidden state per target position, reused as the context of the
        # reconstruction pass below
        decoder_states = numpy.zeros((target.shape[0], 1, next_state.shape[1]), dtype='float32')
        if self.with_coverage:
            coverage = numpy.zeros((ctx.shape[0], 1, self.coverage_dim), dtype='float32')
            if self.coverage_type == 'linguistic':
                # note the return result is a list even when it contains only one element
                fertility = self.enc_dec.compile_fertility([ctx])[0]
        for i in range(len(target)):
            # feed the previous gold target word; -1 marks BOS at the first step
            next_w = numpy.array(target[i - 1]) if i > 0 else -1 * numpy.ones((1,)).astype('int32')
            inps = [next_w, next_state, ctx]
            if self.with_coverage:
                inps.append(coverage)
                if self.coverage_type == 'linguistic':
                    inps.append(fertility)
            ret = self.enc_dec.compile_next_state_and_probs(inps)
            _, next_state, next_w, align = ret[0], ret[1], ret[2], ret[3]
            # update the coverage after attention operation
            if self.with_coverage:
                coverage = ret[4]
            alignment.append(align[:, 0])
            decoder_states[i, 0, :] = next_state[0, :]
        if self.with_reconstruction:
            # reconstruct the source from the decoder states and record the
            # inverse (target-to-source) attention weights
            inverse_alignment = []
            ret = self.enc_dec.compile_inverse_init_and_context([decoder_states])
            inverse_next_state = ret[0]
            mask = numpy.ones((decoder_states.shape[0], 1), dtype='float32')
            for i in range(len(source)):
                inverse_next_w = numpy.array(source[i - 1]) if i > 0 else -1 * numpy.ones((1,)).astype('int32')
                inps = [inverse_next_w, mask, inverse_next_state, decoder_states]
                ret = self.enc_dec.compile_inverse_next_state_and_probs(inps)
                _, inverse_next_state, inverse_next_w, inverse_align = ret[0], ret[1], ret[2], ret[3]
                inverse_alignment.append(inverse_align[:, 0])
        results = [alignment]
        if self.with_coverage:
            # for neural coverage, report the mean over the coverage dimension
            coverage = coverage[:, 0, :].mean(1)
            results.append(coverage)
            if self.coverage_type == 'linguistic':
                results.append(fertility[:, 0])
        if self.with_reconstruction:
            results.append(inverse_alignment)
        return results