train_model.py
import mxnet as mx
import logging
import os
from center_loss import *
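
# The wildcard import above is expected to provide Accuracy and CenterLossMetric,
# which are added to the composite eval metric further down.
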
def get_model_dict(network, data_shape):
    '''
    Return (name, shape) dicts for both the argument and auxiliary parameters,
    so that during fine-tuning the new model only loads the pretrained
    parameters whose shapes still match.
    '''
    arg_shapes, output_shapes, aux_shapes = network.infer_shape(data=(1,) + data_shape)
    arg_names = network.list_arguments()
    aux_names = network.list_auxiliary_states()
    arg_dict = dict(zip(arg_names, arg_shapes))
    aux_dict = dict(zip(aux_names, aux_shapes))
    return arg_dict, aux_dict
def fit(args, network, data_loader, data_shape, batch_end_callback=None, patterns=None, initializers=None):
    '''
    Train `network` on the (train, val) iterators in `data_loader`, optionally
    resuming or fine-tuning from a saved checkpoint. Note that `patterns` and
    `initializers` are accepted but unused; initialization is hard-coded below.
    '''
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging: write to a file if one was given, otherwise to the console
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_file' in args and args.log_file is not None:
        log_file = args.log_file
        log_dir = args.log_dir
        log_file_full_name = os.path.join(log_dir, log_file)
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file_full_name)
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        logger.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)
    # load model: when load_epoch is given, fine-tune from a saved checkpoint
    model_prefix = args.model_prefix
    model_args = {}
    if args.load_epoch is not None:
        assert model_prefix is not None
        tmp = mx.model.FeedForward.load(model_prefix, args.load_epoch)
        # only copy over pretrained params whose shapes still match
        arg_dict, aux_dict = get_model_dict(network, data_shape)
        valid_arg = dict()
        valid_aux = dict()
        # print all the parameters
        print('all params ', arg_dict)
        # arguments
        for k, v in arg_dict.items():
            # skip the data and label entries
            if k == 'data' or k.endswith('label'):
                continue
            # skip params the pretrained model doesn't have
            if k not in tmp.arg_params:
                continue
            if v == tmp.arg_params[k].shape:
                valid_arg[k] = tmp.arg_params[k]
                print('loading arg: {} from pretrained model'.format(k))
        # auxiliary states
        for k, v in aux_dict.items():
            # skip the data and label entries
            if k == 'data' or k.endswith('label'):
                continue
            # skip params the pretrained model doesn't have
            if k not in tmp.aux_params:
                continue
            if v == tmp.aux_params[k].shape:
                valid_aux[k] = tmp.aux_params[k]
                print('loading aux: {} from pretrained model'.format(k))
        model_args = {'arg_params': valid_arg,
                      'aux_params': valid_aux,
                      'begin_epoch': args.load_epoch}
    # save model
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

    # data
    (train, val) = data_loader
    # train
    devs = mx.cpu() if args.gpus is None or args.gpus == '' else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]
    epoch_size = args.num_examples // args.batch_size

    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers
        model_args['epoch_size'] = epoch_size

    if 'lr_factor' in args and args.lr_factor < 1:
        model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step=max(int(epoch_size * args.lr_factor_epoch), 1),
            factor=args.lr_factor)

    if 'clip_gradient' in args and args.clip_gradient is not None:
        model_args['clip_gradient'] = args.clip_gradient

    # disable kvstore for a single device
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None
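
    # Mixed initializer: parameter names matching '.*fc.*' (typically the fully
    # connected layers) get a small Normal init; everything else uses Xavier.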
    init_patterns = ['.*fc.*', '.*']
    init_methods = [mx.init.Normal(sigma=0.001),
                    mx.init.Xavier(factor_type="out", rnd_type="gaussian", magnitude=2.0)]

    print('devices: ', devs)
    model = mx.model.FeedForward(
        ctx=devs,
        symbol=network,
        num_epoch=args.num_epochs,
        learning_rate=args.lr,
        momentum=0.9,
        wd=0.0005,
        initializer=mx.init.Mixed(init_patterns, init_methods),
        **model_args)
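
    # keep any batch callbacks passed in by the caller, and always add a
    # Speedometer that logs training throughput every 50 batches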
    if batch_end_callback is not None:
        if not isinstance(batch_end_callback, list):
            batch_end_callback = [batch_end_callback]
    else:
        batch_end_callback = []
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))
    # custom metrics: classification accuracy plus the center-loss term
    eval_metrics = mx.metric.CompositeEvalMetric()
    eval_metrics.add(Accuracy())
    eval_metrics.add(CenterLossMetric())
    model.fit(
        X=train,
        eval_metric=eval_metrics,
        eval_data=val,
        kvstore=kv,
        batch_end_callback=batch_end_callback,
        epoch_end_callback=checkpoint)
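
# A minimal usage sketch (hypothetical: get_symbol, get_iterators and the exact
# argparse fields are assumptions, not part of this file):
#
#   network = get_symbol(num_classes=10)   # symbol that includes the center-loss output
#   train_iter, val_iter = get_iterators(args.batch_size, data_shape=(1, 28, 28))
#   fit(args, network, (train_iter, val_iter), data_shape=(1, 28, 28))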