-
Notifications
You must be signed in to change notification settings - Fork 1
/
runner.py
76 lines (67 loc) · 2.98 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
import os
import tensorflow as tf
from datetime import datetime
from hparams import hparams
from model.tacotron import Tacotron
'''
Assumed directory structure:
The input directory houses the pre-processed
data and is considered temporary or at least expandable.
Under the input directory is a list of sub-directories
identified by the dataset-name.
/home/<user>/<input_base_dir>/
<dataset_1>/
*.npy
training.txt
...
The output directory houses model meta, checkpoints and logs
and is not temporary since it stores valuable output. Again,
the output directory has a list of sub-directories, here each one
represents a model.
/home/<user>/<output_base_dir>/
<model_1>
logs/
...
model/
checkpoints
model meta
samples/
Synthesized training samples
alignment graphs
...
'''
def main():
    """Parse CLI arguments, prepare output directories, and launch training.

    Joins ``--input_dir``/``--dataset_name`` into the preprocessed-data
    directory and ``--output_dir``/``--model_name`` into the model output
    directory, creates the ``logs``/``meta``/``samples`` sub-directories if
    missing, applies any hyperparameter overrides, and starts Tacotron
    training.
    """
    parser = argparse.ArgumentParser()
    # The three path-building arguments are required: without them the
    # os.path.join calls below would raise an opaque TypeError on None.
    parser.add_argument('--input_dir', required=True,
                        help='Absolute path to base preprocessed data directory')
    parser.add_argument('--output_dir', required=True,
                        help='Absolute path to the base output directory')
    parser.add_argument('--dataset_name', required=True,
                        help='The given dataset has to exist in the given input directory')
    parser.add_argument('--model_name',
                        help='name of model to be trained on, defaults to main. This is just used for file-keeping.',
                        default='main')
    parser.add_argument('--hparams', default='',
                        help='Hyperparameter overrides as a comma-separated list of name=value pairs')
    parser.add_argument('--restore_step', type=int, help='Global step to restore from checkpoint.')
    parser.add_argument('--summary_interval', type=int, default=100,
                        help='Steps between running summary ops.')
    parser.add_argument('--checkpoint_interval', type=int, default=1000,
                        help='Steps between writing checkpoints.')
    parser.add_argument('--msg_interval', type=int, default=100, help='Interval of general training messages')
    # used for broadcasting training updates to slack.
    parser.add_argument('--slack_url', help='Slack webhook URL to get periodic reports.')
    parser.add_argument('--tf_log_level', type=int, default=1, help='Tensorflow C++ log level.')
    args = parser.parse_args()
    # Must be set before TensorFlow's C++ backend starts emitting logs.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)
    # Derive the concrete working directories (see module docstring for layout)
    # and attach them to args so downstream code gets them in one object.
    args.in_dir = os.path.join(args.input_dir, args.dataset_name)
    args.out_dir = os.path.join(args.output_dir, args.model_name)
    args.log_dir = os.path.join(args.out_dir, 'logs')
    args.meta_dir = os.path.join(args.out_dir, 'meta')
    args.sample_dir = os.path.join(args.out_dir, 'samples')
    # create the output directories if needed
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.meta_dir, exist_ok=True)
    os.makedirs(args.sample_dir, exist_ok=True)
    # Apply comma-separated name=value overrides onto the shared hparams object.
    hparams.parse(args.hparams)
    model = Tacotron(hparams)
    model.train(args)
main()