diff --git a/config/default.yaml b/config/default.yaml
index 4152ec4..5c3771d 100644
--- a/config/default.yaml
+++ b/config/default.yaml
@@ -7,8 +7,10 @@
 train:
   num_workers: 32
   batch_size: 16
   optimizer: 'adam'
+  epochs: 1500
   adam:
-    lr: 0.0001
+    init_lr: 0.0001
+    final_lr: 0.00001
     beta1: 0.5
     beta2: 0.9
diff --git a/utils/train.py b/utils/train.py
index 19bd3d5..5c1f5a3 100644
--- a/utils/train.py
+++ b/utils/train.py
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import itertools
 import traceback
 
 from model.generator import Generator
@@ -13,18 +12,35 @@
 from .validation import validate
 
 
+def cosine_decay(init_val, final_val, step, decay_steps):
+    alpha = final_val / init_val
+    cosine_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
+    decayed = (1 - alpha) * cosine_decay + alpha
+    return init_val * decayed
+
+
+def adjust_learning_rate(optimizer, epoch, hp):
+    init_lr = hp.train.adam.init_lr
+    final_lr = hp.train.adam.final_lr
+    decay_steps = hp.train.epochs
+    lr = cosine_decay(init_lr, final_lr, epoch, decay_steps)
+
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
 def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
     model_g = Generator(hp.audio.n_mel_channels).cuda()
     model_d = MultiScaleDiscriminator().cuda()
 
     optim_g = torch.optim.Adam(model_g.parameters(),
-        lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+        lr=hp.train.adam.init_lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
     optim_d = torch.optim.Adam(model_d.parameters(),
-        lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+        lr=hp.train.adam.init_lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
 
     githash = get_commit_hash()
 
-    init_epoch = -1
+    elapsed_epochs = 0
     step = 0
 
     if chkpt_path is not None:
@@ -35,7 +51,7 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
         optim_g.load_state_dict(checkpoint['optim_g'])
         optim_d.load_state_dict(checkpoint['optim_d'])
         step = checkpoint['step']
-        init_epoch = checkpoint['epoch']
+        elapsed_epochs = checkpoint['epoch']
 
         if hp_str != checkpoint['hp_str']:
             logger.warning("New hparams is different from checkpoint. Will use new.")
@@ -54,11 +70,18 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
     try:
         model_g.train()
         model_d.train()
-        for epoch in itertools.count(init_epoch+1):
+
+        epochs = hp.train.epochs - elapsed_epochs
+        for epoch in range(epochs + 1):
             if epoch % hp.log.validation_interval == 0:
                 with torch.no_grad():
                     validate(hp, args, model_g, model_d, valloader, writer, step)
 
+            epoch += elapsed_epochs
+
+            adjust_learning_rate(optim_g, epoch, hp)
+            adjust_learning_rate(optim_d, epoch, hp)
+
             trainloader.dataset.shuffle_mapping()
             loader = tqdm.tqdm(trainloader, desc='Loading train data')
             for (melG, audioG), (melD, audioD) in loader:
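
Below is a minimal standalone sketch of the learning-rate schedule these hunks introduce, assuming the values from config/default.yaml (init_lr=0.0001, final_lr=0.00001, epochs=1500); cosine_decay is adapted from the patch, and the loop over sample epochs is only for illustration.

import math

def cosine_decay(init_val, final_val, step, decay_steps):
    # Half-cosine anneal: init_val at step 0, final_val at step == decay_steps.
    alpha = final_val / init_val
    decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
    return init_val * ((1 - alpha) * decay + alpha)

for epoch in (0, 375, 750, 1125, 1500):
    print(epoch, cosine_decay(0.0001, 0.00001, epoch, 1500))
# -> 1e-4 at epoch 0, ~5.5e-5 at the midpoint, 1e-5 at epoch 1500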