"""
Use this script to run inference on several gens at the same time. You can
dispatch computations on several gpus or run them all on one.
Note that a number of assumptions are made because it's easier this way.
More flexibility can be added: issues and/or discussions are welcome!
In particular: this script only takes rotowire-folder as input, which is assumed
to have the following structure:
- `models` where trained RG models are stored ([pretrained models](https://dl.orangedox.com/rg-models)).
- `output` where everything created by the script is stored: vocabularies,
training examples, extracted list of mentions, etc.
- `gens` you wish to evaluate all the generated texts found here
"""
import multiprocessing as mp
import os
from argparse import ArgumentParser

from data_utils import prep_generated_data
from run import main as inference_run_main
from utils import Container, grouped

def get_parser():
    parser = ArgumentParser(description='Use Information Extractor models '
                                        'on Rotowire data. This script '
                                        'supports usage in parallel runs.')

    group = parser.add_argument_group('Script behavior')
    group.add_argument('--test', dest='test', default=False,
                       action="store_true",
                       help='use test data instead of validation data')
    group.add_argument('--show-correctness', dest="show_correctness",
                       action='store_true',
                       help='When doing inference, append a |RIGHT or |WRONG '
                            'marker to generated tuples')

    group = parser.add_argument_group('File system')
    group.add_argument('--rotowire-folder', dest='rotowire_folder', required=True)
    group.add_argument('--vocab-prefix', dest='vocab_prefix', default='',
                       help='prefix of .dict and .labels files')

    group = parser.add_argument_group('Evaluation options')
    group.add_argument('--batch-size', dest='batch_size', default=32, type=int,
                       help='batch size')
    group.add_argument('--ignore-idx', dest='ignore_idx', default=None, type=int,
                       help='The index of the NONE label in your .labels file')
    group.add_argument('--average-func', dest='average_func', default='arithmetic',
                       choices=['geometric', 'arithmetic'],
                       help='Use the geometric/arithmetic mean to ensemble models')

    group = parser.add_argument_group('GPU options')
    group.add_argument('--gpus', dest='gpus', type=int, nargs='+',
                       help='GPU ids on which to dispatch the runs')
    group.add_argument('--ckpts-per-gpu', dest='ckpts_per_gpu', type=int,
                       default=1, help='Number of runs on the same gpu')
    group.add_argument('--seed', dest='seed', default=3435, type=int,
                       help='Random seed')

    return parser

def build_container(args, gen_filename, gpu):
    # remove .txt from gen_filename, so that .h5 or .json can be appended later on
    filename_pfx = gen_filename[:-4]

    return Container(
        # args shared by both steps
        test=args.test,
        vocab_prefix=os.path.join(args.rotowire_folder, 'output', args.vocab_prefix),

        # args for the first step (i.e. running data_utils.py)
        gen_fi=os.path.join(args.rotowire_folder, 'gens', gen_filename),
        output_fi=os.path.join(args.rotowire_folder, 'output', filename_pfx + '.h5'),
        input_path=os.path.join(args.rotowire_folder, 'json'),

        # args for the second step (i.e. running run.py)
        just_eval=True,
        datafile=os.path.join(args.rotowire_folder, 'output', args.vocab_prefix + '.h5'),
        preddata=os.path.join(args.rotowire_folder, 'output', filename_pfx + '.h5'),
        eval_models=os.path.join(args.rotowire_folder, 'models'),
        gpu=gpu,
        ignore_idx=args.ignore_idx,
        batch_size=args.batch_size,
        average_func=args.average_func,
        show_correctness=args.show_correctness,
        store_results=os.path.join(args.rotowire_folder, filename_pfx + '.json'),
        seed=args.seed,
    )

def single_main(args):
    # Pop the arguments that are only needed by the data-preparation step,
    # so that they are not forwarded to run.py.
    gen_fi = args.pop('gen_fi')
    dict_pfx = args.vocab_prefix  # don't pop this one
    output_fi = args.pop('output_fi')
    input_path = args.pop('input_path')

    # Run data_utils.py in -mode prep_gen_data
    prep_generated_data(gen_fi, dict_pfx, output_fi, path=input_path, test=args.test)

    # Run run.py (results are serialized inside the function)
    _ = inference_run_main(args.to_namespace())

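# `Container` (imported from utils) is assumed here to behave like a small
# argument bag: attribute access (args.test), dict-style pop (args.pop('gen_fi'))
# and conversion to an argparse-style namespace (args.to_namespace()). The real
# implementation lives in utils.py; a minimal sketch under those assumptions:
#
#     from argparse import Namespace
#
#     class Container(dict):
#         __getattr__ = dict.__getitem__
#
#         def to_namespace(self):
#             return Namespace(**self)
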
def main(args=None):
    parser = get_parser()
    args = parser.parse_args(args) if args else parser.parse_args()

    gens_folder = os.path.join(args.rotowire_folder, "gens")
    gens = [filename
            for filename in os.listdir(gens_folder)
            if filename.endswith('.txt')]

    # Gens are processed in groups: each gpu handles 'args.ckpts_per_gpu'
    # runs at a time.
    group_size = len(args.gpus) * args.ckpts_per_gpu
    for grouped_gens in grouped(gens, group_size):
        # We build a list of containers, one for each gen. Runs are dispatched
        # to gpus, each gpu handling 'args.ckpts_per_gpu' of them.
        # `grouped` may pad the last group with None; those placeholders are
        # filtered out below.
        _gpus = [g for g in args.gpus for _ in range(args.ckpts_per_gpu)]
        containers = [
            build_container(args, gen, gpu)
            for gen, gpu in zip(grouped_gens, _gpus)
            if gen is not None
        ]

        # Start one process per gen and wait for the whole group to finish
        # before moving on to the next one.
        processes = [mp.Process(target=single_main, args=(container,))
                     for container in containers]
        for p in processes:
            p.start()
        for p in processes:
            p.join()


if __name__ == '__main__':
    main()
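# Note: `grouped` (imported from utils) is assumed above to chunk an iterable
# into fixed-size tuples, padding the last tuple with None when the gens do not
# divide evenly; hence the `if gen is not None` filter in main(). A minimal
# sketch of such a helper (the actual implementation lives in utils.py) could be:
#
#     from itertools import zip_longest
#
#     def grouped(iterable, n):
#         # yield tuples of length n, padding the last one with None
#         iterators = [iter(iterable)] * n
#         return zip_longest(*iterators, fillvalue=None)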