-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerator.py
121 lines (92 loc) · 3.84 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import torch.nn.functional as F
from res_stack import ResStack
# from res_stack import ResStack
# Full-scale magnitude used to convert [-1, 1] float audio to 16-bit PCM.
MAX_WAV_VALUE = 32768.0


class Generator(nn.Module):
    """Convolutional generator that turns a mel-spectrogram into a waveform.

    The input is expected as ``(batch, mel_channel, frames)``. Four
    weight-normalised transposed convolutions upsample the time axis by
    8 * 8 * 2 * 2 = 256, each followed by a ``ResStack`` block, and a final
    ``Tanh`` bounds the single-channel output to ``[-1, 1]``.

    Args:
        mel_channel: Number of mel bins (input channels of the first conv).
    """

    def __init__(self, mel_channel):
        super(Generator, self).__init__()
        self.mel_channel = mel_channel

        # NOTE: the layer order below fixes the ``generator.<idx>.*`` keys of
        # the state dict, so it must not be changed without migrating
        # existing checkpoints.
        #
        # An earlier variant of this stack (previously kept here as a dead
        # string literal) used 256 instead of 512 channels for the first
        # convolution and had no LeakyReLU between ResStack(64) and the
        # 64 -> 32 transposed convolution.
        self.generator = nn.Sequential(
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1)),

            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4)),

            ResStack(256),

            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=4)),

            ResStack(128),

            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)),

            ResStack(64),

            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)),

            ResStack(32),

            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1)),
            nn.Tanh(),
        )

    def forward(self, mel):
        """Run the generator on ``mel`` and return the raw network output.

        Args:
            mel: Tensor of shape ``(batch, mel_channel, frames)``.

        Returns:
            Tensor of shape ``(batch, 1, frames * 256)`` with values bounded
            by the final ``Tanh``.
        """
        mel = (mel + 5.0) / 5.0  # roughly normalize spectrogram
        return self.generator(mel)

    def eval(self, inference=False):
        """Switch to evaluation mode, optionally stripping weight norm.

        Args:
            inference: When true, weight normalisation is removed from the
                layers as well. Leave it false for validation inside a
                training loop, where the weight-norm parametrisation must
                stay in place.

        Returns:
            ``self``, matching the contract of ``nn.Module.eval`` so that
            calls can be chained.
        """
        super(Generator, self).eval()

        # don't remove weight norm while validation in training loop
        if inference:
            self.remove_weight_norm()

        return self

    def remove_weight_norm(self):
        """Remove weight normalisation from every parametrised layer.

        Layers with an empty state dict (activations, padding) are skipped.
        For the rest, ``nn.utils.remove_weight_norm`` is tried first; when
        the layer itself carries no weight-norm hook, the call is delegated
        to the layer's own ``remove_weight_norm`` method (presumably
        implemented by composite blocks such as ``ResStack``).
        """
        for layer in self.generator:
            if len(layer.state_dict()) == 0:
                continue
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                # Raised when no weight-norm hook is registered directly on
                # this layer; let the layer handle its own sub-modules.
                layer.remove_weight_norm()

    def inference(self, mel):
        """Synthesise 16-bit PCM audio from a mel-spectrogram.

        Args:
            mel: Tensor of shape ``(batch, mel_channel, frames)``.

        Returns:
            ``torch.int16`` tensor with all singleton dimensions squeezed
            out: ``(samples,)`` for a batch of one, ``(batch, samples)``
            otherwise, where ``samples = frames * 256``.
        """
        hop_length = 256
        pad_frames = 10

        # pad input mel with zeros to cut artifact
        # see https://github.com/seungwonpark/melgan/issues/8
        # -11.5129 is ln(1e-5); presumably the log-mel floor, i.e. silence.
        # new_full matches the dtype/device of ``mel`` and the batch size is
        # taken from the input so batched inference works too.
        zero = mel.new_full((mel.size(0), self.mel_channel, pad_frames), -11.5129)
        mel = torch.cat((mel, zero), dim=2)

        audio = self.forward(mel)
        audio = audio.squeeze()  # collapse all dimension except time axis
        # Drop the samples produced by the padding frames; slicing the last
        # axis keeps this correct for both 1-D and batched outputs.
        audio = audio[..., :-(hop_length * pad_frames)]
        audio = MAX_WAV_VALUE * audio
        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
        audio = audio.short()

        return audio
'''
To run this file directly as a script, the package-relative import
    from . import ResStack
has to be replaced with the absolute import
    from res_stack import ResStack
'''
if __name__ == '__main__':
    # Smoke test: push a random 80-bin, 10-frame batch through the network
    # and check that the time axis is upsampled by a factor of 256.
    net = Generator(80)

    dummy_mel = torch.randn(3, 80, 10)
    print(dummy_mel.shape)

    waveform = net(dummy_mel)
    print(waveform.shape)
    assert waveform.shape == torch.Size([3, 1, 2560])

    # Report the number of trainable parameters.
    trainable_count = sum(
        weight.numel() for weight in net.parameters() if weight.requires_grad
    )
    print(trainable_count)