Commit

fix nan loss; add scheduler; edit model
wcaine committed Jul 23, 2021
1 parent c72a632 commit d51c3a1
Showing 3 changed files with 72 additions and 64 deletions.
47 changes: 9 additions & 38 deletions contrastive_generator_model.py
@@ -3,30 +3,7 @@
import config
from downsample_layer import Downsample
from upsample_layer import Upsample

def normalize(x):
norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2)
out = x.div(norm + 1e-7)
return out

class ResnetBlock(nn.Module):
def __init__(self, features):
super().__init__()
layers = []
for i in range(2):
layers += [
nn.ReflectionPad2d(1),
nn.Conv2d(features, features, kernel_size=3),
nn.InstanceNorm2d(features),
]
if i==0:
layers += [
nn.ReLU(True)
]
self.model = nn.Sequential(*layers)

def forward(self, input):
return input + self.model(input)
from resnet_block import ResnetBlock

class Generator(nn.Module):
def __init__(self, in_channels=3, features=64, residuals=9):
@@ -106,27 +83,21 @@ def forward(self, input, encode_only=False, patch_ids=None):
for layer_id, layer in enumerate(self.model):
feat = layer(feat)
if layer_id in [0, 4, 8, 12, 16]:
# if layer_id in [0, 4, 10, 14, 18]:
B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
if num_patches > 0:
if patch_ids is not None:
patch_id = patch_ids[mlp_id]
else:
patch_id = torch.randperm(feat_reshape.shape[1], device=config.DEVICE) #, device=config.DEVICE
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
if patch_ids is not None:
patch_id = patch_ids[mlp_id]
else:
x_sample = feat_reshape
patch_id = []
patch_id = torch.randperm(feat_reshape.shape[1], device=config.DEVICE) #, device=config.DEVICE
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
return_ids.append(patch_id)
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
mlp = getattr(self, 'mlp_%d' % mlp_id)
x_sample = mlp(x_sample)
mlp_id += 1
return_ids.append(patch_id)
x_sample = normalize(x_sample)
norm = x_sample.pow(2).sum(1, keepdim=True).pow(1. / 2)
x_sample = x_sample.div(norm + 1e-7)

if num_patches == 0:
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
return_feats.append(x_sample)
return return_feats, return_ids

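For reference, the inlined L2 normalization above (which replaces the deleted normalize helper) behaves essentially like torch.nn.functional.normalize; a small sketch of my own to check that, not part of the commit:

import torch
import torch.nn.functional as F

x = torch.randn(4, 256)

# Inlined form used in the forward pass above.
norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2)
a = x.div(norm + 1e-7)

# Library form; the epsilon is applied slightly differently (max(norm, eps)
# rather than norm + eps), which only matters for near-zero vectors.
b = F.normalize(x, p=2, dim=1, eps=1e-7)

print(torch.allclose(a, b, atol=1e-5))  # True for non-degenerate inputs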
21 changes: 21 additions & 0 deletions resnet_block.py
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn

class ResnetBlock(nn.Module):
def __init__(self, features):
super().__init__()
layers = []
for i in range(2):
layers += [
nn.ReflectionPad2d(1),
nn.Conv2d(features, features, kernel_size=3),
nn.InstanceNorm2d(features),
]
if i==0:
layers += [
nn.ReLU(True)
]
self.model = nn.Sequential(*layers)

def forward(self, input):
return input + self.model(input)
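A quick shape check for the extracted block; a minimal sketch assuming a standard NCHW float tensor, not part of the commit:

import torch
from resnet_block import ResnetBlock

block = ResnetBlock(features=64)
x = torch.randn(1, 64, 128, 128)   # NCHW
y = block(x)
print(y.shape)                     # torch.Size([1, 64, 128, 128])
# ReflectionPad2d(1) followed by a 3x3 conv keeps the spatial size, so the
# residual addition in forward() lines up with the input.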
68 changes: 42 additions & 26 deletions train.py
@@ -10,6 +10,7 @@
from torchvision.utils import save_image
from contrastive_discriminator_model import Discriminator
from contrastive_generator_model import Generator
from torch.optim import lr_scheduler


def patch_nce_loss(feat_q, feat_k):
@@ -20,7 +21,7 @@ def patch_nce_loss(feat_q, feat_k):
return loss

def calculate_NCE_loss(G, src, tgt):
feat_k_pool, sample_ids = G(src, encode_only=True, patch_ids=None)
feat_k_pool, sample_ids = G(src, encode_only=True)
feat_q_pool, _ = G(tgt, encode_only=True, patch_ids=sample_ids)
total_nce_loss = 0.0
for f_q, f_k in zip(feat_q_pool, feat_k_pool):
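The body of patch_nce_loss is collapsed in this view. For context, a common PatchNCE formulation treats matching patch indices as positives and every other sampled patch as a negative; a sketch of that idea (the temperature value is an assumption, not taken from this repository):

import torch
import torch.nn.functional as F

def patch_nce_loss_sketch(feat_q, feat_k, temperature=0.07):
    # feat_q, feat_k: (num_patches, channels), already L2-normalized by the generator.
    feat_k = feat_k.detach()
    logits = feat_q @ feat_k.t() / temperature                     # query-key similarities
    targets = torch.arange(feat_q.shape[0], device=feat_q.device)  # positives on the diagonal
    return F.cross_entropy(logits, targets)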
@@ -43,7 +44,13 @@ def main():
betas = (0.5, 0.999),
)
opt_mlp = optim.Adam(
itertools.chain(G.mlp_0.parameters(), G.mlp_1.parameters(), G.mlp_2.parameters(), G.mlp_3.parameters(), G.mlp_4.parameters()),
itertools.chain(
G.mlp_0.parameters(),
G.mlp_1.parameters(),
G.mlp_2.parameters(),
G.mlp_3.parameters(),
G.mlp_4.parameters()
),
lr = config.LEARNING_RATE,
betas = (0.5, 0.999),
)
@@ -65,6 +72,11 @@ def main():

out = dict()

lambdalr = lambda epoch: 1.0 - max(0, epoch - config.NUM_EPOCHS/2) / (config.NUM_EPOCHS/2)
scheduler_disc = lr_scheduler.LambdaLR(opt_disc, lr_lambda=lambdalr)
scheduler_gen = lr_scheduler.LambdaLR(opt_gen, lr_lambda=lambdalr)
scheduler_mlp = lr_scheduler.LambdaLR(opt_mlp, lr_lambda=lambdalr)

for epoch in range(config.NUM_EPOCHS):

X_reals = 0
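As a sanity check of the new schedule: the lambda keeps the learning-rate factor at 1.0 for the first half of training and then decays it linearly towards zero. A small worked example, using NUM_EPOCHS = 200 purely as an assumed illustration value:

NUM_EPOCHS = 200  # assumed for illustration; the real value lives in config
lambdalr = lambda epoch: 1.0 - max(0, epoch - NUM_EPOCHS / 2) / (NUM_EPOCHS / 2)

print(lambdalr(0))    # 1.0  -> full learning rate
print(lambdalr(100))  # 1.0  -> decay starts after the halfway point
print(lambdalr(150))  # 0.5  -> halfway through the decay
print(lambdalr(199))  # 0.01 -> close to zero on the last epoch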
@@ -79,14 +91,14 @@
X = X.to(config.DEVICE)

D_Y.set_requires_grad(True)
with torch.cuda.amp.autocast():
fake_Y = G(X)
D_Y_real = D_Y(Y)
D_Y_fake = D_Y(fake_Y.detach())
D_Y_real_loss = mse(D_Y_real, torch.ones_like(D_Y_real))
D_Y_fake_loss = mse(D_Y_fake, torch.zeros_like(D_Y_fake))
D_Y_loss = D_Y_real_loss + D_Y_fake_loss
D_loss = D_Y_loss

fake_Y = G(X)
D_Y_real = D_Y(Y)
D_Y_fake = D_Y(fake_Y.detach())
D_Y_real_loss = mse(D_Y_real, torch.ones_like(D_Y_real))
D_Y_fake_loss = mse(D_Y_fake, torch.zeros_like(D_Y_fake))
D_Y_loss = D_Y_real_loss + D_Y_fake_loss
D_loss = D_Y_loss

opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
@@ -96,25 +108,25 @@ def main():
d_scaler.update()

D_Y.set_requires_grad(False)
with torch.cuda.amp.autocast():
# adversarial loss for generator
D_Y_fake = D_Y(fake_Y)
loss_G_Y = mse(D_Y_fake, torch.ones_like(D_Y_fake))

# PatchNCE loss
PatchNCE_loss = calculate_NCE_loss(G, X, fake_Y)
# adversarial loss for generator
D_Y_fake = D_Y(fake_Y)
loss_G_Y = mse(D_Y_fake, torch.ones_like(D_Y_fake))

# identity loss
if config.LAMBDA_Y>0:
idt_Y = G(Y)
PatchNCE_loss += calculate_NCE_loss(G, idt_Y, fake_Y)
PatchNCE_loss /= 2
# PatchNCE loss
PatchNCE_loss = calculate_NCE_loss(G, X, fake_Y)

# identity loss
if config.LAMBDA_Y>0:
idt_Y = G(Y)
PatchNCE_loss += calculate_NCE_loss(G, idt_Y, fake_Y)
PatchNCE_loss /= 2

# add all together
G_loss = (
loss_G_Y
+ PatchNCE_loss * config.LAMBDA_X
)
# add all together
G_loss = (
loss_G_Y
+ PatchNCE_loss * config.LAMBDA_X
)

opt_gen.zero_grad()
opt_mlp.zero_grad()
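Note that the commit removes the torch.cuda.amp.autocast() contexts around both the discriminator and generator passes (presumably the source of the NaN losses named in the commit title) while keeping the GradScaler calls. GradScaler works fine around a full-precision forward pass; a standalone sketch of the retained pattern, with a placeholder model and hyper-parameters that are not from this repository (requires a CUDA device):

import torch

model = torch.nn.Linear(8, 1).cuda()
opt = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda")
loss = model(x).pow(2).mean()      # fp32 forward, no autocast context

opt.zero_grad()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()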
@@ -140,6 +152,10 @@ def main():
out['D_loss'] += D_loss.item()
out['loss_G_Y'] += loss_G_Y.item()
out['PatchNCE_loss'] += PatchNCE_loss.item()
scheduler_disc.step()
scheduler_gen.step()
scheduler_mlp.step()
print(f"lr: disc={scheduler_disc.get_last_lr()} gen={scheduler_gen.get_last_lr()} mlp={scheduler_mlp.get_last_lr()}")

if epoch % 5 == 0 and config.SAVE_MODEL:
save_checkpoint(G, opt_gen, filename=f"saved_images_{name}/{epoch}_g.pth")
