adam & attention
haswelliris committed Mar 6, 2018
1 parent 97af532 commit 3b3209d
Showing 5 changed files with 25 additions and 486 deletions.
12 changes: 6 additions & 6 deletions script/config.py
@@ -20,18 +20,18 @@
training_config = {
'logdir' : 'logs', # logdir for log outputs and tensorboard
'tensorboard_freq' : 1, # tensorboard record frequency
'minibatch_size' : 500, # in samples when using ctf reader, per worker
'epoch_size' : 500, # in sequences, when using ctf reader
'log_freq' : 300, # in minibatches
'minibatch_size' : 800, # in samples when using ctf reader, per worker
'epoch_size' : 800, # in sequences, when using ctf reader
'log_freq' : 400, # in minibatches
'max_epochs' : 300,
'lr' : 10,
'lr' : 2,
'train_data' : 'train.ctf', # or 'train.tsv'
'val_data' : 'dev.ctf',
'val_interval' : 1, # interval in epochs to run validation
'stop_after' : 2, # num epochs to stop if no CV improvement
'minibatch_seqs' : 16, # num sequences of minibatch, when using tsv reader, per worker
'distributed_after' : 0, # num sequences after which to start distributed training
'gpu_pad' : 0, #emmmmmmm
'gpu_cnt' : 1, # number of gpus
'multi_gpu' : False, # using multi GPU training
'gpu_cnt' : 4, # number of gpus
'multi_gpu' : True, # using multi GPU training
}
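Since minibatch_size and epoch_size are given per worker, the effective global minibatch grows with gpu_cnt when multi_gpu is on. A quick illustrative calculation (not code from the repo), assuming each of the 4 workers draws its own 800-sample minibatch:

```python
training_config = {'minibatch_size': 800, 'gpu_cnt': 4, 'multi_gpu': True}

workers = training_config['gpu_cnt'] if training_config['multi_gpu'] else 1
global_minibatch = training_config['minibatch_size'] * workers
print(global_minibatch)   # 3200 samples per global step across 4 workers
```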
4 changes: 0 additions & 4 deletions script/helpers.py
@@ -72,13 +72,9 @@ def all_spans_loss(start_logits, start_y, end_logits, end_y):
return logZ - C.sequence.last(C.sequence.gather(start_logits, start_y)) - C.sequence.last(C.sequence.gather(end_logits, end_y))

def seq_hardmax(logits):
# [#][dim=1]
seq_max = C.layers.Fold(C.element_max, initial_state=C.constant(-1e+30, logits.shape))(logits)
# [#,c][dim] find the position of the max-scoring word
s = C.equal(logits, C.sequence.broadcast_as(seq_max, logits))
# [#,c][dim] find the position of the first occurrence of the max-scoring word
s_acc = C.layers.Recurrence(C.plus)(s)
# zero everywhere except at the max-scoring word
return s * C.equal(s_acc, 1) # only pick the first one

class LambdaFunc(C.ops.functions.UserFunction):
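The seq_hardmax helper above converts a sequence of logits into a 0/1 mask that marks only the first position attaining the maximum. A minimal NumPy sketch of the same idea, with an illustrative function name that is not part of the repo:

```python
import numpy as np

def seq_hardmax_np(logits):
    """0/1 mask that is 1 only at the first position holding the sequence maximum."""
    is_max = (logits == logits.max()).astype(np.float32)       # every position tied for the max
    first_only = (np.cumsum(is_max) == 1).astype(np.float32)   # running count == 1 marks the first hit
    return is_max * first_only

print(seq_hardmax_np(np.array([0.1, 0.9, 0.9, 0.3])))  # -> [0. 1. 0. 0.]
```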
28 changes: 11 additions & 17 deletions script/polymath.py
@@ -100,6 +100,11 @@ def attention_layer(self, context, query):
#convert query's sequence axis to static
qvw, qvw_mask = C.sequence.unpack(q_processed, padding_value=0).outputs

# This part deserves some explanation
# It is the attention layer
# In the paper they use a 6 * dim dimensional vector
# here we split it in three parts because the different parts
# participate in very different operations
# so W * [h; u; h.* u] becomes w1 * h + w2 * u + w3 * (h.*u)
ws1 = C.parameter(shape=(2 * self.hidden_dim, 1), init=C.glorot_uniform())
ws2 = C.parameter(shape=(2 * self.hidden_dim, 1), init=C.glorot_uniform())
@@ -112,9 +117,11 @@ def attention_layer(self, context, query):
# qvw*ws3: [#][*,200], whu:[#,c][*]
# whu = C.reshape(C.reduce_sum(c_processed * C.sequence.broadcast_as(qvw * ws3, c_processed), axis=1), (-1,))
S1 = wh + C.sequence.broadcast_as(wu, c_processed) + att_bias # [#,c][*]
# mask out values outside of Query, and fill in gaps with -1e+30 as neutral value for both reduce_log_sum_exp and reduce_max
qvw_mask_expanded = C.sequence.broadcast_as(qvw_mask, c_processed)
S1 = C.element_select(qvw_mask_expanded, S1, C.constant(-1e+30))
q_attn = C.reshape(C.softmax(S1), (-1,1)) # [#,c][*,1]
#q_attn = print_node(q_attn)
c2q = C.reshape(C.reduce_sum(C.sequence.broadcast_as(qvw, q_attn) * q_attn, axis=0),(-1)) # [#,c][200]

max_col = C.reduce_max(S1) # [#,c][1] max-scoring word in q
@@ -132,7 +139,6 @@ def attention_layer(self, context, query):
hh_attn = C.reshape(C.softmax(S2), (-1,1))
c2c = C.reshape(C.reduce_sum(C.sequence.broadcast_as(hvw, hh_attn)*hh_attn, axis=0), (-1,))

# original passage, query representation, passage representation w.r.t. the query words, passage context representation
att_context = C.splice(c_processed, c2q, q2c_out, c2c)

return C.as_block(
@@ -142,9 +148,6 @@ def attention_layer(self, context, query):
'attention_layer')
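The comment at the top of this hunk explains that the paper's similarity term W * [h; u; h.*u] is split into three products. A small NumPy check of that equivalence (toy dimensions, names chosen for illustration only):

```python
import numpy as np

d = 4                                      # stands in for 2 * hidden_dim
rng = np.random.default_rng(0)
h, u = rng.normal(size=d), rng.normal(size=d)
w1, w2, w3 = rng.normal(size=d), rng.normal(size=d), rng.normal(size=d)

W = np.concatenate([w1, w2, w3])           # the single 3*d weight vector from the paper
joint = W @ np.concatenate([h, u, h * u])  # W * [h; u; h.*u]
split = w1 @ h + w2 @ u + w3 @ (h * u)     # the three-way split used in attention_layer

assert np.isclose(joint, split)
```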

def modeling_layer(self, attention_context):
'''
Overall representation of the passage after the first reading pass
'''
att_context = C.placeholder(shape=(8*self.hidden_dim,))
#modeling layer
# todo: use dropout in optimized_rnn_stack from cudnn once API exposes it
@@ -164,11 +167,9 @@ def output_layer(self, attention_context, modeling_context):
att_context = C.placeholder(shape=(8*self.hidden_dim,))
mod_context = C.placeholder(shape=(2*self.hidden_dim,))
#output layer
# project to [#,c][1]
start_logits = C.layers.Dense(1, name='out_start')(C.dropout(C.splice(mod_context, att_context), self.dropout))
if self.two_step:
start_hardmax = seq_hardmax(start_logits)
# get the semantic representation of the max-scoring word [#][dim]
att_mod_ctx = C.sequence.last(C.sequence.gather(mod_context, start_hardmax))
else:
start_prob = C.softmax(start_logits)
@@ -199,16 +200,12 @@ def model(self):

#input layer
c_processed, q_processed = self.input_layer(cgw,cnw,cc,qgw,qnw,qc).outputs
q_int = C.sequence.last(q_processed) # [#][2*hidden_dim]
# attention layer output:[#,c][8*hidden_dim]

# attention layer
att_context = self.attention_layer(c_processed, q_processed)

# modeling layer output:[#,c][2*hidden_dim]
# modeling layer
mod_context = self.modeling_layer(att_context)
q_int_extend = C.sequence.broadcast_as(q_int, c_processed) # [#,c][2*hidden_dim]
W_match = C.parameter(shape=(2*self.hidden_dim,2*self.hidden_dim), init=C.glorot_uniform())
match_level = C.sigmoid(C.reduce_sum(C.times_transpose(W_match, mod_context)*q_int_extend)) #[#,c][1]
mod_context = mod_context * match_level

# output layer
start_logits, end_logits = self.output_layer(att_context, mod_context).outputs
@@ -218,7 +215,4 @@ def model(self):
end_loss = seq_loss(end_logits, ae)
#paper_loss = start_loss + end_loss
new_loss = all_spans_loss(start_logits, ab, end_logits, ae)
new_loss.as_numpy = False
res = C.combine([start_logits, end_logits])
res.as_numpy=False
return res, new_loss
return C.combine([start_logits, end_logits]), new_loss
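One of the additions in model() above gates the modeling-layer output by how well each passage position matches a summary of the question (q_int): a bilinear score through W_match, squashed by a sigmoid, then used to scale mod_context. A hedged NumPy sketch of that gating step, with toy shapes:

```python
import numpy as np

rng = np.random.default_rng(1)
T, d = 5, 4                              # passage length and 2 * hidden_dim (toy sizes)
mod_context = rng.normal(size=(T, d))    # [#,c][2*hidden_dim]
q_int = rng.normal(size=d)               # last state of the processed question, [#][2*hidden_dim]
W_match = rng.normal(size=(d, d))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

match_level = sigmoid(mod_context @ W_match @ q_int)  # one score in (0, 1) per passage position
gated = mod_context * match_level[:, None]            # scale each position by its match score

print(match_level.shape, gated.shape)                 # (5,) (5, 4)
```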
12 changes: 8 additions & 4 deletions script/train_pm.py
@@ -146,7 +146,10 @@ def train(data_path, model_path, log_file, config_file, restore=False, profiling
dummies.append(C.reduce_sum(C.assign(ema_p, 0.999 * ema_p + 0.001 * p)))
dummy = C.combine(dummies)

learner = C.adadelta(z.parameters, lr)
# learner = C.adadelta(z.parameters, lr)
momentum_as_time_constant = C.momentum_as_time_constant_schedule(700)
learner = C.adam(parameters=z.parameters, lr=lr, momentum=momentum_as_time_constant,
gradient_clipping_threshold_per_sample=30)

if C.Communicator.num_workers() > 1:
learner = C.data_parallel_distributed_learner(learner)
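The learner above switches from Adadelta to Adam with a momentum schedule and per-sample gradient clipping. As a reminder of what one Adam step does, here is a minimal NumPy sketch of the update rule (an illustration, not CNTK's internals; all hyperparameters shown are placeholders except the clipping threshold of 30 taken from the diff):

```python
import numpy as np

def adam_step(p, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, clip=30.0):
    """One Adam update for parameter p with gradient g; m and v are running moments."""
    g = np.clip(g, -clip, clip)               # rough analogue of per-sample gradient clipping
    m = beta1 * m + (1 - beta1) * g           # first moment (the "momentum" part)
    v = beta2 * v + (1 - beta2) * g * g       # second moment (uncentered variance)
    m_hat = m / (1 - beta1 ** t)              # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    return p - lr * m_hat / (np.sqrt(v_hat) + eps), m, v
```

Note that momentum_as_time_constant_schedule(700) expresses the first-moment decay as a time constant: per sample it corresponds to exp(-1/700) ≈ 0.9986.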
@@ -172,7 +175,8 @@ def train(data_path, model_path, log_file, config_file, restore=False, profiling
if restore and os.path.isfile(model_file):
trainer.restore_from_checkpoint(model_file)
#after restore always re-evaluate
epoch_stat['best_val_err'] = validate_model(os.path.join(data_path, training_config['val_data']), model, polymath,config_file)
# epoch_stat['best_val_err'] = validate_model(os.path.join(data_path, training_config['val_data']), model, polymath,config_file)
epoch_stat['best_val_err'] = 100

def post_epoch_work(epoch_stat):
trainer.summarize_training_progress()
@@ -193,7 +197,7 @@ def post_epoch_work(epoch_stat):
save_flag = True
fail_cnt = 0
while save_flag:
if fail_cnt > 100:
if fail_cnt > 1000:
print("ERROR: failed to save models")
break
try:
@@ -372,7 +376,7 @@ def test(test_data, model_path, model_file, config_file):
best_span_score = symbolic_best_span(begin_prediction, end_prediction)
predicted_span = C.layers.Recurrence(C.plus)(begin_prediction - C.sequence.past_value(end_prediction))

batch_size = 32 # in sequences
batch_size = 4 # in sequences
misc = {'rawctx':[], 'ctoken':[], 'answer':[], 'uid':[]}
tsv_reader = create_tsv_reader(loss, test_data, polymath, batch_size, 1, is_test=True, misc=misc)
results = {}
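In test() above, predicted_span is the running sum of begin_prediction minus the one-step-delayed end_prediction, which yields a 0/1 mask over the tokens from the predicted start through the predicted end (inclusive). A NumPy sketch of the same trick with made-up one-hot predictions:

```python
import numpy as np

begin = np.array([0, 0, 1, 0, 0, 0])             # one-hot start prediction
end   = np.array([0, 0, 0, 0, 1, 0])             # one-hot end prediction

shifted_end = np.concatenate(([0], end[:-1]))    # past_value: shift end one step right
span_mask = np.cumsum(begin - shifted_end)       # running sum, as Recurrence(C.plus) does
print(span_mask)                                 # [0 0 1 1 1 0] -> tokens 2..4 are the span
```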
