-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_runner.py
80 lines (66 loc) · 2.86 KB
/
train_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
from jax import random
from experiments import Experiment
import os
import yaml
if(__name__ == '__main__'):
# Load the command line arguments
parser = argparse.ArgumentParser(description='Train an NIF model')
parser.add_argument('--name',
action='store',
type=str,
help='Name of model. Is used to load existing checkpoints.',
default='GLOW')
parser.add_argument('--quantize',
action='store',
type=int,
help='The number of bits to use in quantization',
default=3)
parser.add_argument('--start_it',
action='store',
type=int,
help='Sets the training iteration to start on. -1 finds most recent',
default=-1)
parser.add_argument('--checkpoint_interval',
action='store',
type=int,
help='Sets the number of iterations between each test',
default=5000)
parser.add_argument('--experiment_root',
action='store',
type=str,
help='The root directory of the experiments folder',
default='Experiments')
parser.add_argument('--experiment_def_path',
action='store',
type=str,
help='The root directory of the experiments definitions folder',
default='experiment_definitions')
parser.add_argument('--optimizer_settings',
action='store',
type=str,
help='Settings for the optimizer',
default='adam_warmup')
args = parser.parse_args()
# Load the experiment object
exp = Experiment(args.name,
args.quantize,
args.checkpoint_interval,
start_it=args.start_it,
experiment_root=args.experiment_root)
# Initialize the experiment from scratch or from a checkpoint
if(exp.current_iteration is None):
key = random.PRNGKey(0)
# Load the model and optimizer definitions
model_def_path = os.path.join(args.experiment_def_path, args.name+'.yaml')
opt_def_path = os.path.join(args.experiment_def_path, args.optimizer_settings+'.yaml')
with open(model_def_path) as f:
model_meta_data = yaml.safe_load(f)
with open(opt_def_path) as f:
opt_meta_data = yaml.safe_load(f)
# Initialize the experiment
exp.create_experiment_from_meta_data(key, model_meta_data, opt_meta_data)
else:
exp.load_experiment()
# Train
exp.train()