-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathend2end_trimodal.py
60 lines (49 loc) · 1.94 KB
/
end2end_trimodal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import numpy as np
import tensorflow as tf
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from models.trimodals import E2E_Hierarchical_MCTN_Model
from utils.args import parse_args
from utils.data_loader import load_and_preprocess_data
from utils.utils import get_preds_statistics
np.random.seed(123)
tf.set_random_seed(456)
# arguments
args, configs = parse_args()
# data load
is_cycled = configs['translation1']['is_cycled']
feats_dict = \
load_and_preprocess_data(max_seq_len=configs['general']['max_seq_len'],
train_split=configs['general']['train_split'],
is_cycled=is_cycled)
print("FORMING SEQ2SEQ MODEL...")
features = args.feature # e.g. ['f', 'a', 't']
assert len(features) == 3, 'Wrong number of features'
end2end_model = E2E_Hierarchical_MCTN_Model(configs,
features,
feats_dict)
print("PREP FOR TRAINING...")
filename = '_'.join(args.feature) + "_attention_seq2seq_" + \
str("bi_directional" if configs['translation1']['is_bidirectional']
else '') + \
"_bimodal.h5"
output_dir = configs['general']['output_dir']
weights_path = os.path.join(output_dir, filename)
if not os.path.exists(output_dir):
os.mkdir(output_dir)
callbacks = [
EarlyStopping(monitor='val_loss', patience=args.train_patience, verbose=0),
ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True,
verbose=1),
]
print("TRAINING NOW...")
end2end_model.train(weights_path=weights_path,
n_epochs=args.train_epoch,
val_split=args.val_split,
batch_size=args.batch_size,
callbacks=callbacks)
print("PREDICTING...")
predictions = end2end_model.predict()
predictions = predictions.reshape(-1, )
get_preds_statistics(predictions, feats_dict['test_labels'])