#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Trains, evaluates and saves the model network using a queue."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import imp
import json
import logging
import numpy as np
import os.path
import sys

# configure logging; note that both branches currently apply
# the same configuration
if 'TV_IS_DEV' in os.environ and os.environ['TV_IS_DEV']:
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.INFO,
                        stream=sys.stdout)
else:
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.INFO,
                        stream=sys.stdout)

import time
from shutil import copyfile

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import tensorvision.utils as utils
import tensorvision.core as core

flags = tf.app.flags
FLAGS = flags.FLAGS
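

# Illustrative sketch of a minimal hypes.json covering the keys this script
# reads (hypothetical values; the 'dirs' entries may also be filled in by
# utils.set_dirs, and the model modules may require further keys):
#
#     {
#         "dirs": {"base_path": ".",
#                  "output_dir": "RUNS/my_run",
#                  "data_dir": "DATA"},
#         "model": {"input_file": "inputs/data_input.py",
#                   "architecture_file": "encoder/architecture.py",
#                   "objective_file": "decoder/objective.py",
#                   "optimizer_file": "optimizer/solver.py"},
#         "solver": {"batch_size": 16, "max_steps": 10000}
#     }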


def _copy_parameters_to_traindir(hypes, input_file, target_name, target_dir):
    """
    Helper to copy files defining the network to the saving dir.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    input_file : str
        name of source file
    target_name : str
        target name
    target_dir : str
        directory where training data is saved
    """
    target_file = os.path.join(target_dir, target_name)
    input_file = os.path.join(hypes['dirs']['base_path'], input_file)
    copyfile(input_file, target_file)


def _start_enqueuing_threads(hypes, q, sess, data_input):
    """
    Start the enqueuing threads of the data_input module.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    q : dict
        Dict containing the 'train' and 'val' queues
    sess : tf.Session
    data_input : module
        The data input module
    """
    with tf.name_scope('data_load'):
        data_input.start_enqueuing_threads(hypes, q['train'], 'train',
                                           sess, hypes['dirs']['data_dir'])
        data_input.start_enqueuing_threads(hypes, q['val'], 'val', sess,
                                           hypes['dirs']['data_dir'])


def initialize_training_folder(hypes):
    """
    Create the training folder and copy all model files into it.

    The model will be executed from the training folder and all
    outputs will be saved there.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    target_dir = os.path.join(hypes['dirs']['output_dir'], "model_files")
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # Add an additional handler which saves the console output
    # to a logging file in the training folder.
    logging_file = os.path.join(hypes['dirs']['output_dir'], "output.log")
    filewriter = logging.FileHandler(logging_file, mode='w')
    formatter = logging.Formatter(
        '%(asctime)s %(name)-3s %(levelname)-3s %(message)s')
    filewriter.setLevel(logging.INFO)
    filewriter.setFormatter(formatter)
    logging.getLogger('').addHandler(filewriter)

    # TODO: read more about loggers and make file logging neater.

    target_file = os.path.join(target_dir, 'hypes.json')
    with open(target_file, 'w') as outfile:
        json.dump(hypes, outfile, indent=2)
    _copy_parameters_to_traindir(
        hypes, hypes['model']['input_file'], "data_input.py", target_dir)
    _copy_parameters_to_traindir(
        hypes, hypes['model']['architecture_file'], "architecture.py",
        target_dir)
    _copy_parameters_to_traindir(
        hypes, hypes['model']['objective_file'], "objective.py", target_dir)
    _copy_parameters_to_traindir(
        hypes, hypes['model']['optimizer_file'], "solver.py", target_dir)
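

# After initialize_training_folder the output directory looks roughly like
# this (checkpoints named model.ckpt-* are added later by the training loop):
#
#     <output_dir>/
#         output.log
#         model_files/
#             hypes.json
#             data_input.py
#             architecture.py
#             objective.py
#             solver.py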


def build_training_graph(hypes, modules):
    """
    Build the tensorflow graph out of the model files.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    modules : tuple
        The modules loaded in utils.

    Returns
    -------
    tuple
        (q, train_op, loss, eval_lists) where
        q is a dict with keys 'train' and 'val' which contains the queues,
        train_op is a tensorflow op,
        loss is a scalar loss tensor,
        eval_lists is a dict with keys 'train' and 'val'
    """
    data_input, arch, objective, solver = modules

    global_step = tf.Variable(0.0, trainable=False)

    q, logits, decoder = {}, {}, {}
    image_batch, label_batch = {}, {}
    eval_lists = {}

    # Add Input Producers to the Graph
    with tf.name_scope('Input'):
        q['train'] = data_input.create_queues(hypes, 'train')
        input_batch = data_input.inputs(hypes, q['train'], 'train',
                                        hypes['dirs']['data_dir'])
        image_batch['train'], label_batch['train'] = input_batch

    logits['train'] = arch.inference(hypes, image_batch['train'], 'train')

    decoder['train'] = objective.decoder(hypes, logits['train'])

    # Add to the Graph the Ops for loss calculation.
    loss = objective.loss(hypes, decoder['train'], label_batch['train'])

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = solver.training(hypes, loss, global_step=global_step)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_lists['train'] = objective.evaluation(hypes, decoder['train'],
                                               label_batch['train'])

    # Add the Validation Cycle to the Graph
    with tf.name_scope('Validation'):
        with tf.name_scope('Input'):
            q['val'] = data_input.create_queues(hypes, 'val')
            input_batch = data_input.inputs(hypes, q['val'], 'val',
                                            hypes['dirs']['data_dir'])
            image_batch['val'], label_batch['val'] = input_batch

        # Reuse the weights of the training network for validation.
        tf.get_variable_scope().reuse_variables()

        logits['val'] = arch.inference(hypes, image_batch['val'], 'val')
        decoder['val'] = objective.decoder(hypes, logits['val'])
        eval_lists['val'] = objective.evaluation(hypes, decoder['val'],
                                                 label_batch['val'])

    return q, train_op, loss, eval_lists


def maybe_download_and_extract(hypes):
    """
    Download the data if it isn't downloaded yet.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    f = os.path.join(hypes['dirs']['base_path'], hypes['model']['input_file'])
    data_input = imp.load_source("input", f)
    if hasattr(data_input, 'maybe_download_and_extract'):
        data_input.maybe_download_and_extract(hypes, hypes['dirs']['data_dir'])


def _write_evaluation_to_summary(evaluation_results, summary_writer, phase,
                                 global_step, sess):
    """
    Write the evaluation_results to the summary file.

    Parameters
    ----------
    evaluation_results : tuple
        The output of do_eval
    summary_writer : tf.train.SummaryWriter
    phase : string
        Name of Operation to write
    global_step : tensor or int
        Current training step
    sess : tf.Session
    """
    # write result to summary
    summary = tf.Summary()
    eval_names, avg_results = evaluation_results
    for name, result in zip(eval_names, avg_results):
        summary.value.add(tag='Evaluation/' + phase + '/' + name,
                          simple_value=result)
    summary_writer.add_summary(summary, global_step)


def _print_training_status(hypes, step, loss_value, start_time, sess_coll):
    """Log the current training status and update the events file."""
    duration = (time.time() - start_time) / int(utils.cfg.step_show)
    examples_per_sec = hypes['solver']['batch_size'] / duration
    sec_per_batch = float(duration)
    info_str = utils.cfg.step_str
    sess, saver, summary_op, summary_writer, coord, threads = sess_coll
    logging.info(info_str.format(step=step,
                                 total_steps=hypes['solver']['max_steps'],
                                 loss_value=loss_value,
                                 sec_per_batch=sec_per_batch,
                                 examples_per_sec=examples_per_sec))

    # Update the events file.
    summary_str = sess.run(summary_op)
    summary_writer.add_summary(summary_str, step)


def _write_checkpoint_to_disk(hypes, step, sess_coll):
    """Save the current model state as a checkpoint in the output dir."""
    sess, saver, summary_op, summary_writer, coord, threads = sess_coll
    checkpoint_path = os.path.join(hypes['dirs']['output_dir'],
                                   'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)


def _do_evaluation(hypes, step, sess_coll, eval_dict):
    """Run the evaluation on the training and validation data."""
    sess, saver, summary_op, summary_writer, coord, threads = sess_coll

    logging.info('Doing Evaluation with Training Data.')
    eval_results = core.do_eval(hypes, eval_dict, phase='train',
                                sess=sess)
    _write_evaluation_to_summary(eval_results, summary_writer,
                                 'train', step, sess)

    logging.info('Doing Evaluation with Validation Data.')
    eval_results = core.do_eval(hypes, eval_dict, phase='val',
                                sess=sess)
    _write_evaluation_to_summary(eval_results, summary_writer,
                                 'val', step, sess)


def run_training_step(hypes, step, start_time, graph_ops, sess_coll):
    """Run one iteration of training."""
    # Unpack operations for later use
    sess = sess_coll[0]
    q, train_op, loss, eval_dict = graph_ops

    # Run the training Step
    _, loss_value = sess.run([train_op, loss])

    # Write the summaries and print an overview fairly often.
    if step % int(utils.cfg.step_show) == 0:
        # Print status to stdout.
        _print_training_status(hypes, step, loss_value, start_time, sess_coll)
        # Reset timer
        start_time = time.time()

    # Save a checkpoint periodically.
    if (step + 1) % int(utils.cfg.step_eval) == 0 or \
            (step + 1) == hypes['solver']['max_steps']:
        # write checkpoint to disk
        _write_checkpoint_to_disk(hypes, step, sess_coll)
        # Reset timer
        start_time = time.time()

    # Do an evaluation and print the current state.
    if (step + 1) % int(utils.cfg.step_eval) == 0 or \
            (step + 1) == hypes['solver']['max_steps']:
        # run evaluation on training and validation data
        _do_evaluation(hypes, step, sess_coll, eval_dict)
        # Reset timer
        start_time = time.time()

    return start_time
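

# Scheduling example (illustrative values; the actual step_show and step_eval
# settings come from utils.cfg): with step_show = 100 and step_eval = 1000,
# run_training_step prints a status line every 100 steps and writes a
# checkpoint followed by a train/val evaluation every 1000 steps as well as
# at the final step.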


def do_training(hypes):
    """
    Train model for a number of steps.

    This trains the model for at most hypes['solver']['max_steps'].
    It shows an update every utils.cfg.step_show steps and writes
    the model to hypes['dirs']['output_dir'] every utils.cfg.step_eval
    steps.

    Parameters
    ----------
    hypes : dict
        Hyperparameters
    """
    # Load the model modules: data input, architecture, objective and solver.
    modules = utils.load_modules_from_hypes(hypes)
    data_input, arch, objective, solver = modules

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():

        # build the graph based on the loaded modules
        graph_ops = build_training_graph(hypes, modules)
        q = graph_ops[0]

        # prepare the tv session
        sess_coll = core.start_tv_session(hypes)
        sess, saver, summary_op, summary_writer, coord, threads = sess_coll

        # Start the data load
        _start_enqueuing_threads(hypes, q, sess, data_input)

        # And then after everything is built, start the training loop.
        start_time = time.time()
        for step in xrange(hypes['solver']['max_steps']):
            start_time = run_training_step(hypes, step, start_time,
                                           graph_ops, sess_coll)

        # stop the input threads
        coord.request_stop()
        coord.join(threads)


def main(_):
    """Run main function."""
    if FLAGS.hypes is None:
        logging.error("No hypes are given.")
        logging.error("Usage: tv-train --hypes hypes.json")
        exit(1)

    if FLAGS.gpus is None:
        if 'TV_USE_GPUS' in os.environ:
            if os.environ['TV_USE_GPUS'] == 'force':
                logging.error('Please specify a GPU.')
                logging.error('Usage: tv-train --gpus <ids>')
                exit(1)
            else:
                gpus = os.environ['TV_USE_GPUS']
                logging.info("GPUs are set to: %s", gpus)
                os.environ['CUDA_VISIBLE_DEVICES'] = gpus
    else:
        logging.info("GPUs are set to: %s", FLAGS.gpus)
        os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpus

    with open(tf.app.flags.FLAGS.hypes, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)

    utils.load_plugins()
    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    logging.info("Initialize training folder")
    initialize_training_folder(hypes)
    maybe_download_and_extract(hypes)
    logging.info("Start training")
    do_training(hypes)


if __name__ == '__main__':
    tf.app.run()
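

# Typical invocation, mirroring the usage hints logged in main() (the
# registration of the --hypes and --gpus flags is assumed to happen inside
# the imported tensorvision modules):
#
#     tv-train --hypes hypes.json
#     python train.py --hypes hypes.json --gpus 0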