# -*- coding: utf-8 -*-
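"""Grid-search driver for the sentence-classification experiments.

For every hyperparameter combination in ``grid_parameters`` this script builds
the complex embedding network (QDNN), trains it, and logs the best validation
accuracy of the run.
"""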
import argparse
import itertools

import keras.backend as K

from params import Params
from models import representation as models
from dataset import classification as dataset
from tools import units
from tools.save import save_experiment
# Fall back to a single shard when no GPU is visible (avoids a modulo-by-zero
# in the grid sharding below).
gpu_count = max(1, len(units.get_available_gpus()))
dir_path, global_logger = units.getLogger()
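# Train and evaluate a single configuration: build the QDNN model from
# `params`, fit on the training split (monitoring the test split during
# training), then score the held-out validation split.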
def run(params, reader):
    params = dataset.process_embedding(reader, params)
    qdnn = models.setup(params)
    model = qdnn.getModel()
    model.compile(loss=params.loss,
                  optimizer=units.getOptimizer(name=params.optimizer, lr=params.lr),
                  metrics=['accuracy'])
    model.summary()
    (train_x, train_y), (test_x, test_y), (val_x, val_y) = reader.get_processed_data()
    # Optional pretraining on sentiment-dictionary data:
    # pretrain_x, pretrain_y = dataset.get_sentiment_dic_training_data(reader, params)
    # model.fit(x=pretrain_x, y=pretrain_y, batch_size=params.batch_size, epochs=3,
    #           validation_data=(test_x, test_y))
    history = model.fit(x=train_x, y=train_y, batch_size=params.batch_size,
                        epochs=params.epochs, validation_data=(test_x, test_y))
    evaluation = model.evaluate(x=val_x, y=val_y)
    save_experiment(model, params, evaluation, history, reader)
    return history, evaluation
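# Full search grid over all seven classification datasets. Inline comments
# record alternative values tried in earlier runs.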
grid_parameters = {
    "dataset_name": ["MR", "TREC", "SST_2", "SST_5", "MPQA", "SUBJ", "CR"],
    "wordvec_path": ["glove/glove.6B.50d.txt"],  # also tried: "glove/glove.6B.100d.txt", "glove/glove.6B.300d.txt", "glove/normalized_vectors.txt"
    "loss": ["categorical_crossentropy"],  # also tried: "mean_squared_error", "categorical_hinge"
    "optimizer": ["rmsprop"],  # also tried: "adagrad", "adamax", "nadam", "adadelta", "adam"
    "batch_size": [16],  # also tried: 32
    "activation": ["sigmoid"],
    "amplitude_l2": [0],  # also tried: 0.0000005, 0.0000001
    "phase_l2": [0.00000005],
    "dense_l2": [0],  # also tried: 0.0001, 0.00001
    "measurement_size": [1400, 1600, 1800, 2000],  # also tried: 50100
    "lr": [0.1],  # also tried: 1, 0.01
    "dropout_rate_embedding": [0.9],  # also tried: 0.5, 0.75, 0.8, 1
    "dropout_rate_probs": [0.9],  # also tried: 0.5, 0.75, 0.8, 1
}
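# NOTE: the assignment below overwrites the full grid above, so only this
# reduced SST_2 ablation grid is actually searched. Comment out whichever
# definition you do not want to use.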
grid_parameters = {
    "dataset_name": ["SST_2"],
    "wordvec_path": ["glove/glove.6B.50d.txt"],  # also tried: "glove/glove.6B.100d.txt", "glove/glove.6B.300d.txt", "glove/normalized_vectors.txt"
    "loss": ["categorical_crossentropy"],  # also tried: "mean_squared_error", "categorical_hinge"
    "optimizer": ["rmsprop"],  # also tried: "adagrad", "adamax", "nadam", "adadelta", "adam"
    "batch_size": [16],  # also tried: 32
    "activation": ["sigmoid"],
    "amplitude_l2": [0],  # also tried: 0.0000005, 0.0000001
    "phase_l2": [0.00000005],
    "dense_l2": [0],  # also tried: 0.0001, 0.00001
    "measurement_size": [30],  # also tried: 50100
    "lr": [0.1],  # also tried: 1, 0.01
    "dropout_rate_embedding": [0.9],  # also tried: 0.5, 0.75, 0.8, 1
    "dropout_rate_probs": [0.9],  # also tried: 0.5, 0.75, 0.8, 1
    "ablation": [1],
    # "network_type": ["ablation"],
}
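# The grid can be sharded across GPUs by launching one process per device, e.g.:
#   python run_classification.py -gpu_num 2 -gpu 0
#   python run_classification.py -gpu_num 2 -gpu 1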
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='running the complex embedding network')
    parser.add_argument('-gpu_num', action='store', dest='gpu_num', type=int, default=gpu_count,
                        help='total number of gpus (one grid shard per gpu)')
    parser.add_argument('-gpu', action='store', dest='gpu', type=int, default=0,
                        help='index of the gpu shard this process should run')
    args = parser.parse_args()

    # Take every configuration whose index falls on this process's shard.
    parameters = [arg for index, arg in enumerate(itertools.product(*grid_parameters.values()))
                  if index % args.gpu_num == args.gpu]
    parameters = parameters[::-1]

    params = Params()
    config_file = 'config/qdnn.ini'  # the dataset is defined in the config
    params.parse_config(config_file)

    reader = None
    for parameter in parameters:
        old_dataset = params.dataset_name
        params.setup(zip(grid_parameters.keys(), parameter))
        # Rebuild the data reader only when the dataset changes between runs
        # (and on the first iteration, when no reader exists yet).
        if reader is None or old_dataset != params.dataset_name:
            print("switch {} to {}".format(old_dataset, params.dataset_name))
            reader = dataset.setup(params)
            params.reader = reader
        # params.print()
        # params.save(dir_path)
        history, evaluation = run(params, reader)
        # "val_acc" is the history key used by the Keras versions this code
        # targets; newer Keras releases name it "val_accuracy".
        global_logger.info("{} : {:.4f}".format(params.to_string(), max(history.history["val_acc"])))
        K.clear_session()