-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
22 lines (19 loc) · 1.09 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torch.autograd import Variable
from torchvision.utils import save_image
def sample_images(generator, test_dataloader, args, epoch, batches_done):
"""Saves a generated sample from the validation set"""
imgs = next(iter(test_dataloader))
real_A = Variable(imgs['A'].type(torch.FloatTensor).cuda())
real_B = Variable(imgs['B'].type(torch.FloatTensor).cuda())
fake_B = generator(real_A)
img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
save_image(img_sample, '%s-%s/%s/%s-%s.png' % (args.exp_name, args.dataset_name, args.img_result_dir, batches_done, epoch), nrow=5, normalize=True)
class LambdaLR():
def __init__(self, epoch_num, epoch_start, decay_start_epoch):
assert ((epoch_num - decay_start_epoch) > 0), "Decay must start before the training session ends!"
self.epoch_num = epoch_num
self.epoch_start = epoch_start
self.decay_start_epoch = decay_start_epoch
def step(self, epoch):
return 1.0 - max(0, epoch + 1 + self.epoch_start - self.decay_start_epoch)/(self.epoch_num - self.decay_start_epoch)