From cc1ec7229d735d64cae0d18093092e197b613d69 Mon Sep 17 00:00:00 2001
From: begeekmyfriend
Date: Thu, 19 Dec 2019 10:34:55 +0800
Subject: [PATCH 1/2] Use dynamic learning rate decay for convergence

Signed-off-by: begeekmyfriend
---
 config/default.yaml |  4 +++-
 utils/train.py      | 34 ++++++++++++++++++++++++++++------
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/config/default.yaml b/config/default.yaml
index 4152ec4..5cdb620 100644
--- a/config/default.yaml
+++ b/config/default.yaml
@@ -7,8 +7,10 @@ train:
   num_workers: 32
   batch_size: 16
   optimizer: 'adam'
+  epochs: 800
   adam:
-    lr: 0.0001
+    init_lr: 0.0001
+    final_lr: 0.00001
     beta1: 0.5
     beta2: 0.9
 ---
diff --git a/utils/train.py b/utils/train.py
index 19bd3d5..5ce17b1 100644
--- a/utils/train.py
+++ b/utils/train.py
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import itertools
 import traceback
 
 from model.generator import Generator
@@ -13,18 +12,35 @@ from .validation import validate
 
 
+def cosine_decay(init_val, final_val, step, decay_steps):
+    alpha = final_val / init_val
+    cosine_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
+    decayed = (1 - alpha) * cosine_decay + alpha
+    return init_val * decayed
+
+
+def adjust_learning_rate(optimizer, epoch, hp):
+    init_lr = hp.train.adam.init_lr
+    final_lr = hp.train.adam.final_lr
+    decay_steps = hp.train.epochs
+    lr = cosine_decay(init_lr, final_lr, epoch, decay_steps)
+
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
 def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
     model_g = Generator(hp.audio.n_mel_channels).cuda()
     model_d = MultiScaleDiscriminator().cuda()
 
     optim_g = torch.optim.Adam(model_g.parameters(),
-        lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+        lr=hp.train.adam.init_lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
     optim_d = torch.optim.Adam(model_d.parameters(),
-        lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+        lr=hp.train.adam.init_lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
 
     githash = get_commit_hash()
 
-    init_epoch = -1
+    elapsed_epochs = 0
     step = 0
 
     if chkpt_path is not None:
@@ -35,7 +51,7 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
         optim_g.load_state_dict(checkpoint['optim_g'])
         optim_d.load_state_dict(checkpoint['optim_d'])
         step = checkpoint['step']
-        init_epoch = checkpoint['epoch']
+        elapsed_epochs = checkpoint['epoch']
 
         if hp_str != checkpoint['hp_str']:
             logger.warning("New hparams is different from checkpoint. Will use new.")
@@ -54,11 +70,14 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
     try:
         model_g.train()
         model_d.train()
-        for epoch in itertools.count(init_epoch+1):
+
+        epochs = hp.train.epochs - elapsed_epochs
+        for epoch in range(epochs):
             if epoch % hp.log.validation_interval == 0:
                 with torch.no_grad():
                     validate(hp, args, model_g, model_d, valloader, writer, step)
 
+            epoch += elapsed_epochs
             trainloader.dataset.shuffle_mapping()
             loader = tqdm.tqdm(trainloader, desc='Loading train data')
             for (melG, audioG), (melD, audioD) in loader:
@@ -67,6 +86,9 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
                 melD = melD.cuda()
                 audioD = audioD.cuda()
 
+                adjust_learning_rate(optim_g, epoch, hp)
+                adjust_learning_rate(optim_d, epoch, hp)
+
                 # generator
                 optim_g.zero_grad()
                 fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]

From b450a915f52e30edff63c933b2dda2dcc9d55f03 Mon Sep 17 00:00:00 2001
From: begeekmyfriend
Date: Fri, 27 Dec 2019 10:35:58 +0800
Subject: [PATCH 2/2] Fix loop counting

Signed-off-by: begeekmyfriend
---
 config/default.yaml | 2 +-
 utils/train.py      | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/config/default.yaml b/config/default.yaml
index 5cdb620..5c3771d 100644
--- a/config/default.yaml
+++ b/config/default.yaml
@@ -7,7 +7,7 @@ train:
   num_workers: 32
   batch_size: 16
   optimizer: 'adam'
-  epochs: 800
+  epochs: 1500
   adam:
     init_lr: 0.0001
     final_lr: 0.00001
diff --git a/utils/train.py b/utils/train.py
index 5ce17b1..5c1f5a3 100644
--- a/utils/train.py
+++ b/utils/train.py
@@ -72,12 +72,16 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
         model_d.train()
 
         epochs = hp.train.epochs - elapsed_epochs
-        for epoch in range(epochs):
+        for epoch in range(epochs + 1):
             if epoch % hp.log.validation_interval == 0:
                 with torch.no_grad():
                     validate(hp, args, model_g, model_d, valloader, writer, step)
 
             epoch += elapsed_epochs
+
+            adjust_learning_rate(optim_g, epoch, hp)
+            adjust_learning_rate(optim_d, epoch, hp)
+
             trainloader.dataset.shuffle_mapping()
             loader = tqdm.tqdm(trainloader, desc='Loading train data')
             for (melG, audioG), (melD, audioD) in loader:
@@ -86,9 +90,6 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
                 melD = melD.cuda()
                 audioD = audioD.cuda()
 
-                adjust_learning_rate(optim_g, epoch, hp)
-                adjust_learning_rate(optim_d, epoch, hp)
-
                 # generator
                 optim_g.zero_grad()
                 fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]
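
A minimal standalone sketch of the schedule PATCH 1/2 introduces, for checking it outside the training loop. cosine_decay is copied from the diff above; the driver block, the sampled epochs, and the printed format are illustrative only, with init_lr, final_lr and epochs mirroring config/default.yaml after PATCH 2/2. The learning rate starts at init_lr for epoch 0 and reaches final_lr at the last epoch.

    import math

    def cosine_decay(init_val, final_val, step, decay_steps):
        # Half-cosine interpolation: returns init_val at step 0 and final_val
        # at step == decay_steps (same formula as added in utils/train.py).
        alpha = final_val / init_val
        cosine_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
        decayed = (1 - alpha) * cosine_decay + alpha
        return init_val * decayed

    if __name__ == "__main__":
        # Illustrative values taken from config/default.yaml
        init_lr, final_lr, epochs = 0.0001, 0.00001, 1500
        for epoch in (0, 375, 750, 1125, 1500):
            lr = cosine_decay(init_lr, final_lr, epoch, epochs)
            print(f"epoch {epoch:4d}: lr = {lr:.6e}")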