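"""Fine-tune a Whisper checkpoint on one or more Hugging Face speech datasets.

Example invocation (placeholders in angle brackets are illustrative, not fixed values):

    python train.py \
        --train_datasets "<dataset>|<config>|<splits>" \
        --eval_datasets "<dataset>|<config>|<splits>" \
        --language "mn,Mongolian" \
        --whisper-size small \
        --hf-username <your-hf-username>
"""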
import argparse
from transformers import (
WhisperConfig,
WhisperFeatureExtractor,
WhisperTokenizer,
WhisperProcessor,
WhisperForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer)
# local imports
from multiple_datasets.utils import show_argparse
from multiple_datasets.dataset_utils import (
merge_datasets,
KEEP_CHARS)
from multiple_datasets.evaluate_utils import evaluate_and_save, get_compute_metrics_func
from multiple_datasets.data_collators import DataCollatorSpeechSeq2SeqWithPadding
from multiple_datasets.hub_default_utils import push_to_hub_using_whisper_template
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--train_datasets', default=None, help='dataset|config|splits,dataset|config|splits')
parser.add_argument('--eval_datasets', default=None, help='dataset|config|splits,dataset|config|splits')
    parser.add_argument('--interleave', action='store_true', help='if passed, interleave the training datasets instead of concatenating them')
parser.add_argument('--whisper-size', default='small')
    parser.add_argument('--language', default='mn,Mongolian', help='language code,Full Language Name, e.g. "mn,Mongolian"')
    parser.add_argument('--keep-chars', default=KEEP_CHARS, help='characters to keep during text preprocessing')
parser.add_argument('--train-batch-size', default=32, type=int)
parser.add_argument('--eval-batch-size', default=16, type=int)
parser.add_argument('--max-steps', default=1000, type=int)
parser.add_argument('--num-workers', default=8, type=int)
parser.add_argument('--version', default=1, type=int)
# for reading and writing preprocessed dataset
parser.add_argument('--hf-username', type=str, required=True)
    parser.add_argument('--use-cached-ds', action='store_true', help='if passed, try to load an already-preprocessed dataset from the Hub instead of preprocessing again')
    parser.add_argument('--merge-audio-to-max', action='store_true', help='if passed, merge audio samples up to `dataset_utils.MAX_AUDIO_DURATION`')
# Trainer.train()
    parser.add_argument('--resume-from-checkpoint', action='store_true', help='if passed, resume training from the latest checkpoint in the output directory')
args = parser.parse_args()
show_argparse(args)
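    # '--language' packs the short code and the full name together, e.g. 'mn,Mongolian';
    # the code goes into the model/output names, the full name into the tokenizer/processor.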
lan, language = args.language.split(',')
model_name = f'openai/whisper-{args.whisper_size}'
output_dir = f"whisper-{args.whisper_size}-{lan}-{args.version}"
print('model_name:', model_name)
print('output_dir:', output_dir)
## Load
config = WhisperConfig.from_pretrained(model_name)
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
tokenizer = WhisperTokenizer.from_pretrained(model_name, language=language, task="transcribe")
processor = WhisperProcessor.from_pretrained(model_name, language=language, task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(model_name)
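    # For fine-tuning, no decoder tokens are forced and no tokens are suppressed,
    # so the model learns to emit the language/task tokens itself.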
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False # not compatible with gradient checkpointing
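    # The collator pads audio features and label token ids independently within each batch;
    # compute_metrics reports word error rate (WER), which also selects the best checkpoint below.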
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
compute_metrics = get_compute_metrics_func(tokenizer)
## Preprocess
train_ds = merge_datasets(
args.train_datasets, args.interleave,
args.keep_chars, feature_extractor, tokenizer,
args.hf_username, args.use_cached_ds, args.merge_audio_to_max)
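    # Only the training mix honours --interleave; the evaluation datasets are merged
    # with interleaving disabled.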
eval_ds = merge_datasets(
args.eval_datasets, False,
args.keep_chars, feature_extractor, tokenizer,
args.hf_username, args.use_cached_ds, args.merge_audio_to_max)
    ## Train
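    # Effective batch size is per_device_train_batch_size * gradient_accumulation_steps per GPU;
    # evaluation and checkpointing run every 1000 steps, and the checkpoint with the
    # lowest WER is reloaded at the end (load_best_model_at_end).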
training_args = Seq2SeqTrainingArguments(
output_dir=output_dir, # change to a repo name of your choice
per_device_train_batch_size=args.train_batch_size,
per_device_eval_batch_size=args.eval_batch_size,
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
learning_rate=1e-5,
warmup_steps=500,
max_steps=args.max_steps,
gradient_checkpointing=True,
fp16=True,
evaluation_strategy="steps",
predict_with_generate=True,
generation_max_length=225,
save_steps=1000,
eval_steps=1000,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=True,
remove_unused_columns=False, # important when we use set_transform
save_total_limit=5,
#
dataloader_num_workers=args.num_workers
)
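    # Passing the feature extractor as `tokenizer` lets the Trainer save it next to
    # each checkpoint (and push it to the Hub together with the model).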
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=train_ds,
eval_dataset=eval_ds,
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
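    # Ctrl-C during training is caught below so the run still ends with an
    # evaluation pass and a push to the Hub using the current model state.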
try:
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
except KeyboardInterrupt:
print('KEYBOARD INTERRUPTED! Starting evaluation with current state')
trainer.is_in_train = False
metrics = evaluate_and_save(trainer, tokenizer, feature_extractor)
push_to_hub_using_whisper_template(
args.train_datasets, args.hf_username, metrics, lan, output_dir
)