Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dynamic learning rate decay for convergence #39

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ train:
num_workers: 32
batch_size: 16
optimizer: 'adam'
epochs: 1500
adam:
lr: 0.0001
init_lr: 0.0001
final_lr: 0.00001
beta1: 0.5
beta2: 0.9
---
Expand Down
35 changes: 29 additions & 6 deletions utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import traceback

from model.generator import Generator
Expand All @@ -13,18 +12,35 @@
from .validation import validate


def cosine_decay(init_val, final_val, step, decay_steps):
alpha = final_val / init_val
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be init_val / final_val?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The learning rate decays. You might write a demo for testing.

init_val = 1e-4
final_val = 1e-5

Copy link

@casper-hansen casper-hansen Jan 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the following source it's "Minimum learning rate value as a fraction of learning_rate."
https://docs.w3cub.com/tensorflow~python/tf/train/cosine_decay/

Given the values, it looks like it's correct. The naming is just off - it should be the smallest value in the numerator and largest value in the denominator.

cosine_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
decayed = (1 - alpha) * cosine_decay + alpha
return init_val * decayed


def adjust_learning_rate(optimizer, epoch, hp):
init_lr = hp.train.adam.init_lr
final_lr = hp.train.adam.final_lr
decay_steps = hp.train.epochs
lr = cosine_decay(init_lr, final_lr, epoch, decay_steps)

for param_group in optimizer.param_groups:
param_group['lr'] = lr


def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
model_g = Generator(hp.audio.n_mel_channels).cuda()
model_d = MultiScaleDiscriminator().cuda()

optim_g = torch.optim.Adam(model_g.parameters(),
lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
lr=hp.train.adam.init_lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
optim_d = torch.optim.Adam(model_d.parameters(),
lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
lr=hp.train.adam.init_lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))

githash = get_commit_hash()

init_epoch = -1
elapsed_epochs = 0
step = 0

if chkpt_path is not None:
Expand All @@ -35,7 +51,7 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
optim_g.load_state_dict(checkpoint['optim_g'])
optim_d.load_state_dict(checkpoint['optim_d'])
step = checkpoint['step']
init_epoch = checkpoint['epoch']
elapsed_epochs = checkpoint['epoch']

if hp_str != checkpoint['hp_str']:
logger.warning("New hparams is different from checkpoint. Will use new.")
Expand All @@ -54,11 +70,18 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
try:
model_g.train()
model_d.train()
for epoch in itertools.count(init_epoch+1):

epochs = hp.train.epochs - elapsed_epochs
for epoch in range(epochs + 1):
if epoch % hp.log.validation_interval == 0:
with torch.no_grad():
validate(hp, args, model_g, model_d, valloader, writer, step)

epoch += elapsed_epochs

adjust_learning_rate(optim_g, epoch, hp)
adjust_learning_rate(optim_d, epoch, hp)

trainloader.dataset.shuffle_mapping()
loader = tqdm.tqdm(trainloader, desc='Loading train data')
for (melG, audioG), (melD, audioD) in loader:
Expand Down