forked from fatchord/WaveRNN
-
Notifications
You must be signed in to change notification settings - Fork 113
/
Copy pathtrain_forward.py
92 lines (78 loc) · 3.66 KB
/
train_forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import itertools
import os
import subprocess
from pathlib import Path
from typing import Union
import torch
from torch import optim
from torch.utils.data.dataloader import DataLoader
from models.fast_pitch import FastPitch
from models.forward_tacotron import ForwardTacotron
from trainer.common import to_device
from trainer.forward_trainer import ForwardTrainer
from trainer.multi_forward_trainer import MultiForwardTrainer
from utils.checkpoints import restore_checkpoint, init_tts_model
from utils.dataset import get_forward_dataloaders
from utils.display import *
from utils.dsp import DSP
from utils.files import read_config
from utils.paths import Paths
def try_get_git_hash() -> Union[str, None]:
    """Return the hash of the currently checked-out git commit, if obtainable.

    Runs ``git rev-parse HEAD`` in the current working directory and returns
    its ASCII-decoded, whitespace-stripped output. This is a best-effort
    lookup: any failure is reported on stdout and ``None`` is returned
    instead of raising.
    """
    git_command = ['git', 'rev-parse', 'HEAD']
    try:
        raw_output = subprocess.check_output(git_command)
        commit_hash = raw_output.decode('ascii').strip()
    except Exception as e:
        # Deliberately broad: a missing hash must never abort the caller.
        print(f'Could not retrieve git hash! {e}')
        commit_hash = None
    return commit_hash
def create_gta_features(model: Union[ForwardTacotron, FastPitch],
                        train_set: DataLoader,
                        val_set: DataLoader,
                        save_path: Path) -> None:
    """Run the model over all batches and save one predicted mel per item.

    The model is put into eval mode and called without gradient tracking on
    every batch of ``train_set`` followed by every batch of ``val_set``. For
    each item of a batch, the corresponding ``'mel_post'`` prediction is cut
    to the item's ``'mel_len'`` along its second axis and written to
    ``save_path/<item_id>.npy``. Progress is streamed once per batch.

    Args:
        model: Model that is called with a batch and returns a mapping
            containing ``'mel_post'``.
        train_set: Loader yielding batches that provide ``'item_id'`` and
            ``'mel_len'`` entries.
        val_set: Loader yielding batches of the same form as ``train_set``.
        save_path: Directory that receives the ``.npy`` files; it is created
            if it does not exist yet.
    """
    # Imported locally so this function does not depend on ``np`` leaking
    # into the module namespace through ``from utils.display import *``.
    import numpy as np

    # Fail-safe: make sure the target directory exists before any inference
    # is done, instead of erroring on the first ``np.save``.
    save_path.mkdir(parents=True, exist_ok=True)

    model.eval()
    device = next(model.parameters()).device  # use same device as model parameters
    iters = len(train_set) + len(val_set)
    dataset = itertools.chain(train_set, val_set)
    for i, batch in enumerate(dataset, 1):
        batch = to_device(batch, device=device)
        with torch.no_grad():
            pred = model(batch)
        gta = pred['mel_post'].cpu().numpy()
        for j, item_id in enumerate(batch['item_id']):
            # Explicit int so the slice bound is a plain Python integer
            # (presumably a 0-dim tensor entry -- TODO confirm against the
            # dataset's collate function).
            mel_len = int(batch['mel_len'][j])
            mel = gta[j][:, :mel_len]
            np.save(str(save_path / f'{item_id}.npy'), mel, allow_pickle=False)
        bar = progbar(i, iters)
        msg = f'{bar} {i}/{iters} Batches '
        stream(msg)
if __name__ == '__main__':
    # Script entry point: parse the command line, load the config, restore the
    # latest forward-model checkpoint, then either export GTA features
    # (--force_gta) or start training.
    parser = argparse.ArgumentParser(description='Train ForwardTacotron TTS')
    parser.add_argument('--force_gta', '-g', action='store_true',
                        help='Force the model to create GTA features')
    parser.add_argument('--config', metavar='FILE', default='configs/singlespeaker.yaml',
                        help='The config containing all hyperparams.')
    args = parser.parse_args()

    config = read_config(args.config)
    # Only fill in the git hash when the config does not already carry one.
    if 'git_hash' not in config or config['git_hash'] is None:
        config['git_hash'] = try_get_git_hash()
    dsp = DSP.from_config(config)
    paths = Paths(config['data_path'], config['tts_model_id'])

    # Explicit check instead of `assert`: assertions are stripped when Python
    # runs with -O, and a missing directory should yield the same helpful
    # message as an empty one rather than a bare OS error from os.listdir.
    if not os.path.isdir(paths.alg) or len(os.listdir(paths.alg)) == 0:
        raise FileNotFoundError(
            f'Could not find alignment files in {paths.alg}, please predict '
            f'alignments first with python train_tacotron.py --force_align!')

    force_gta = args.force_gta
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print('Using device:', device)

    # Instantiate Forward TTS Model
    model = init_tts_model(config).to(device)
    print(f'\nInitialized tts model: {model}\n')
    optimizer = optim.Adam(model.parameters())
    restore_checkpoint(model=model, optim=optimizer,
                       path=paths.forward_checkpoints / 'latest_model.pt',
                       device=device)

    if force_gta:
        print('Creating Ground Truth Aligned Dataset...\n')
        train_set, val_set = get_forward_dataloaders(
            paths=paths, batch_size=8, **config['training']['filter'])
        create_gta_features(model, train_set, val_set, paths.gta)
    else:
        # Both trainers share the same constructor and train() call, so only
        # the class differs between the multi-* model types and the rest.
        is_multi = config['tts_model'] in ['multi_forward_tacotron', 'multi_fast_pitch']
        trainer_cls = MultiForwardTrainer if is_multi else ForwardTrainer
        trainer = trainer_cls(paths=paths, dsp=dsp, config=config)
        trainer.train(model, optimizer)