-
Notifications
You must be signed in to change notification settings - Fork 53
/
run_eval.py
376 lines (326 loc) · 13.2 KB
/
run_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
#!/usr/bin/env python
# coding=utf-8
# Copyright BigScience, The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reproduce the main evaluation in `Multitask Prompted Training Enables Zero-Shot Task Generalization` using PyTorch.
This script is heavily adapted from https://github.com/huggingface/transformers/blob/7533d30acd975027e83a548e4c38e06fa335291b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
"""
import argparse
import logging
import os
import random
import json
import datasets
import torch
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from transformers import (
AutoConfig,
AutoTokenizer,
default_data_collator,
)
from promptsource.templates import DatasetTemplates
from t0.data_collator import DataCollatorForMultipleChoice
from t0.model import ModelBase
# Module-level logger. main() lowers its level to ERROR on every process except
# the local main process, so only one process per machine logs at INFO.
logger = logging.getLogger(__name__)
def parse_args():
    """Parse the command-line options for a single T0 zero-shot evaluation run.

    ``--dataset_name``, ``--template_name`` and ``--model_name_or_path`` are
    required; every other option has a default.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description="Reproduce main evaluation in T0.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).",
        required=True,
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--template_name",
        type=str,
        default=None,
        help="The template/prompt name",
        required=True,
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=1024,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
            " sequences shorter will be padded if `--pad_to_max_length` is passed."
        ),
    )
    parser.add_argument(
        "--target_max_length",
        type=int,
        default=256,
        help="Target max length. Sequences longer than this will be truncated."
    )
    parser.add_argument(
        "--pad_to_max_length",
        action="store_true",
        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models. The list of T0 variants can be found on `https://huggingface.co/bigscience/T0_3B`",
        required=True,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        # The script writes its metrics here as `results.json`; no model is saved.
        help="Where to store the evaluation results (written as `results.json`).",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Activate debug mode and run evaluation only with a subset of data.",
    )
    parser.add_argument(
        "--parallelize",
        action="store_true",
        help=(
            "If passed, will call `model.parallelize` which splits the model on all GPUs available when applicable (model parallelism). "
            "Note that this feature is still experimental in HF Transformers."
        ),
    )
    args = parser.parse_args()
    return args
def _configure_logging(accelerator):
    """Configure stdlib, ``datasets`` and ``transformers`` logging for this process.

    Only the local main process (one per machine) logs at INFO; all other
    processes are restricted to errors.
    """
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()


def _load_raw_dataset(args):
    """Load the evaluation split of ``args.dataset_name``, trimmed in debug mode.

    For ANLI, ``--dataset_config_name`` names the split to load and must be one
    of ``dev_r1``/``dev_r2``/``dev_r3``. For every other dataset it is the
    dataset configuration, and the ``validation`` split is loaded.

    Raises:
        ValueError: if ANLI is requested without a valid ``dev_r*`` split name.
    """
    # In distributed evaluation, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if args.dataset_name == "anli":
        if args.dataset_config_name not in ("dev_r1", "dev_r2", "dev_r3"):
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                "For ANLI, `dataset_config_name` should be either `dev_r1`, `dev_r2` or `dev_r3`."
            )
        raw_datasets = load_dataset(args.dataset_name, split=args.dataset_config_name)
    else:
        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name, split="validation")
    # TODO(Victor): enable loading pre-processed dataset from https://huggingface.co/datasets/bigscience/P3

    if args.debug:
        # Keep at most the first 100 evaluation examples.
        raw_datasets = raw_datasets.select(range(min(len(raw_datasets), 100)))
    return raw_datasets


def _load_config_and_tokenizer(args):
    """Load the model config and a left-padding tokenizer that has a pad token.

    Returns:
        tuple: ``(config, tokenizer)``.

    Raises:
        ValueError: if neither a config/model name nor a tokenizer/model name is
            given, or if no pad token can be found or derived.
    """
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(
            "Either `args.config_name` or `args.model_name_or_path` should be provided."
        )

    tokenizer_source = args.tokenizer_name or args.model_name_or_path
    if not tokenizer_source:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source, use_fast=not args.use_slow_tokenizer, padding_side="left"
    )

    if tokenizer.pad_token is None:
        # Reuse an existing special token as padding, in priority order EOS > BOS > SEP.
        # `break` makes the first available token win (previously the last one did).
        for token in (tokenizer.eos_token, tokenizer.bos_token, tokenizer.sep_token):
            if token is not None:
                tokenizer.pad_token = token
                break
        if tokenizer.pad_token is None:
            raise ValueError("Please define a pad token id.")
    return config, tokenizer


def _make_preprocess_function(args, tokenizer, template, column_names, padding):
    """Build the batched ``datasets.map`` function producing multiple-choice features.

    For every example, the promptsource template renders the input text, the
    gold target, and the list of answer choices. The tokenized input is
    repeated once per answer choice; the tokenized choices become ``labels``
    and ``labels_attention_mask``; ``targets`` is the index of the gold choice.
    """
    def preprocess_function(examples):
        batch_size = len(examples[column_names[0]])
        input_texts = []
        target_texts = []
        answer_choices_texts = []
        for i in range(batch_size):
            ex = {k: examples[k][i] for k in column_names}
            input_text, target = template.apply(ex)
            ex_answer_choices = template.get_answer_choices_list(ex)
            # `targets` below is the gold choice's index, so the target must be a choice.
            if target not in ex_answer_choices:
                raise ValueError(
                    f"Target {target!r} is not one of the answer choices {ex_answer_choices!r}."
                )
            input_texts.append(input_text)
            target_texts.append(target)
            answer_choices_texts.append(ex_answer_choices)

        tokenized_inputs = tokenizer(
            input_texts,
            padding=padding,
            max_length=args.max_length,
            truncation=True,
            add_special_tokens=False,
        )
        tokenized_targets = [
            tokenizer(
                ans_choi,
                # No padding at this stage (the original notes padding is on the right here).
                padding=False,
                # Answer choices are targets, so honour `--target_max_length`.
                max_length=args.target_max_length,
                truncation=True,
            )
            for ans_choi in answer_choices_texts
        ]

        # Number of answer choices per example, used to repeat each input.
        num_choices = [len(t["input_ids"]) for t in tokenized_targets]
        features = {
            k: [[elem] * num_choices[idx] for idx, elem in enumerate(v)]
            for k, v in tokenized_inputs.items()
        }
        features["labels"] = [t["input_ids"] for t in tokenized_targets]
        features["labels_attention_mask"] = [t["attention_mask"] for t in tokenized_targets]
        features["targets"] = [
            choices.index(t) for choices, t in zip(answer_choices_texts, target_texts)
        ]
        return features

    return preprocess_function


def main():
    """Evaluate a model zero-shot on one prompted dataset split and report accuracy.

    The result is printed. When ``--output_dir`` is given, the main process also
    writes it, together with the dataset/template identifiers, to
    ``<output_dir>/results.json``.
    """
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us.
    accelerator = Accelerator()
    _configure_logging(accelerator)

    # `--output_dir` is optional (default None); only create it when one was given.
    if accelerator.is_main_process and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    raw_datasets = _load_raw_dataset(args)
    column_names = raw_datasets.column_names

    config, tokenizer = _load_config_and_tokenizer(args)
    model = ModelBase.from_config(
        config=config,
        model_name_or_path=args.model_name_or_path,
        parallelize=args.parallelize,
    )

    padding = "max_length" if args.pad_to_max_length else False

    # Get the prompt to apply and the possible targets.
    # TODO(Victor): If pulling from pre-processed data, remove this logic.
    # For ANLI, `dataset_config_name` is a split name (see _load_raw_dataset), not a
    # dataset subset, so it must not be used to look up the promptsource templates.
    if args.dataset_config_name is None or args.dataset_name == "anli":
        prompts = DatasetTemplates(f"{args.dataset_name}")
    else:
        prompts = DatasetTemplates(f"{args.dataset_name}/{args.dataset_config_name}")
    template = prompts[args.template_name]

    preprocess_function = _make_preprocess_function(
        args, tokenizer, template, column_names, padding
    )
    with accelerator.main_process_first():
        eval_dataset = raw_datasets.map(
            preprocess_function, batched=True, remove_columns=column_names
        )

    # Log a few random samples from the eval set (fewer if the set is tiny).
    num_samples_to_log = min(3, len(eval_dataset))
    for index in random.sample(range(len(eval_dataset)), num_samples_to_log):
        logger.info(f"Sample {index} of the evaluation set: {eval_dataset[index]}.")

    # DataLoaders creation:
    if args.pad_to_max_length:
        # Inputs were already padded to max length during preprocessing, so the default
        # data collator only needs to convert everything to tensors.
        data_collator = default_data_collator
    else:
        # Otherwise, `DataCollatorForMultipleChoice` applies dynamic padding (to the longest
        # sample in the batch). With fp16, pad to a multiple of 8 so Tensor Cores can be used.
        data_collator = DataCollatorForMultipleChoice(
            tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
        )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
    )

    # Use the device given by the `accelerator` object (a parallelized model places itself).
    if not args.parallelize:
        model.to(accelerator.device)

    # Prepare everything with our `accelerator`.
    eval_dataloader = accelerator.prepare(eval_dataloader)

    # Metrics
    metric = load_metric("accuracy")

    # Eval!
    total_batch_size = args.per_device_eval_batch_size * accelerator.num_processes
    logger.info("***** Running evaluation *****")
    logger.info(f" Num examples = {len(eval_dataset)}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_eval_batch_size}")
    logger.info(f" Total eval batch size (w. parallel, distributed) = {total_batch_size}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(len(eval_dataloader)), disable=not accelerator.is_local_main_process)

    model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            predictions = model(batch)
        # NOTE(review): in multi-process runs, `gather` may include samples duplicated to
        # fill the last batch; newer accelerate versions offer `gather_for_metrics` for this.
        metric.add_batch(
            predictions=accelerator.gather(predictions),
            references=accelerator.gather(batch["targets"]),
        )
        progress_bar.update(1)
    progress_bar.close()

    eval_metric = metric.compute()
    accelerator.print(f"Result: {eval_metric}")

    results = {
        "dataset_name": args.dataset_name,
        "dataset_config_name": args.dataset_config_name,
        "template_name": args.template_name,
        "evaluation": eval_metric
    }
    if accelerator.is_main_process and args.output_dir is not None:
        with open(os.path.join(args.output_dir, "results.json"), "w") as f:
            json.dump(results, f, indent=4)
# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()