from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from model import *
from utils import *


class AdaINLoss(nn.Module):
    def __init__(self, lambda_):
        super().__init__()
        # lambda_ is a hyperparameter that dictates the relative importance of
        # content vs. style: the greater lambda_, the harder the model tries to
        # match style; the smaller, the harder it tries to preserve content.
        # In symbols (Eq. 11 of the paper): L = L_c + lambda_ * L_s.
        self.lambda_ = lambda_

    def contentLoss(self, content_emb, output_emb):
        """Takes two embedding tensors generated by VGG and returns the L2 norm
        (i.e. the Euclidean distance) between them. [See Eq. 12 of the paper.]"""
        return torch.norm(content_emb - output_emb)
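
    # Eq. 12 in symbols, where f is the VGG encoder, g the decoder and t the
    # AdaIN target (the training loop below passes model.t as content_emb):
    #   L_c = || f(g(t)) - t ||_2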

    def styleLoss(self, style_activations, output_activations):
        """Takes two lists of activation tensors hooked from VGG layers during
        forward passes using our style image and our output image as inputs.
        Computes the L2 norm between each pair's means and standard deviations
        and returns the sum. [See Eq. 13 of the paper.]"""
        mu_sum = 0
        sigma_sum = 0
        for style_act, output_act in zip(style_activations, output_activations):
            # Sum the statistics losses across all hooked layers
            mu_sum += torch.norm(mu(style_act) - mu(output_act))
            sigma_sum += torch.norm(sigma(style_act) - sigma(output_act))
        return mu_sum + sigma_sum
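
    # Eq. 13 in symbols, summing over the hooked VGG layers phi_i, with s the
    # style image and g(t) the decoded output:
    #   L_s = sum_i ||mu(phi_i(s)) - mu(phi_i(g(t)))||_2
    #       + sum_i ||sigma(phi_i(s)) - sigma(phi_i(g(t)))||_2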

    def totalLoss(self, content_emb, output_emb, style_activations, output_activations):
        """Calculates the overall loss. [See Eq. 11 of the paper.]"""
        content_loss = self.contentLoss(content_emb, output_emb)
        style_loss = self.styleLoss(style_activations, output_activations)
        return content_loss + self.lambda_ * style_loss

    def forward(self, content_emb, output_emb, style_activations, output_activations):
        """Returns the total loss averaged over the batch; to calculate a
        single-image loss, pass arguments with a batch size of 1."""
        total = self.totalLoss(content_emb, output_emb, style_activations, output_activations)
        return total / content_emb.shape[0]
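

# A minimal smoke test for AdaINLoss (a sketch: shapes are illustrative, and
# `mu`/`sigma` are assumed to come from utils and return per-channel
# statistics). Uncomment to sanity-check the loss on random tensors:
#
#   _criterion = AdaINLoss(lambda_=7.5)
#   _emb = torch.randn(2, 512, 32, 32)            # stand-in for a VGG embedding
#   _acts = [torch.randn(2, 64, 128, 128)] * 4    # stand-ins for hooked activations
#   _loss = _criterion(_emb, _emb * 0.9, _acts, [a * 0.9 for a in _acts])
#   print(_loss.item())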


if __name__ == "__main__":
    # Set hyperparameters
    bs = 1
    epochs = 5
    lr = 6e-4
    wd = 0.001
    lambda_ = 7.5
    alpha = 0.1                            # smoothing factor for the running-loss average
    style_layers = ['1', '6', '10', '20']  # vgg layer indices hooked for the style loss
    debug_layers = [0, 3, 5, 7]            # decoder layer indices hooked for nan debugging
    # Load content and style datasets
    content_images = DataLoader(ImageDataset('./train/content'), batch_size=bs, shuffle=True, num_workers=0)
    style_images = DataLoader(ImageDataset('./train/style'), batch_size=bs, shuffle=True, num_workers=0)
    # Set cost function
    criterion = AdaINLoss(lambda_)
    # Load the model and freeze the VGG encoder; only the decoder is trained
    model = AdaINStyle()
    for p in model.vgg.parameters():
        p.requires_grad = False
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
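    # The frozen VGG weights have requires_grad=False and so never receive
    # gradients; Adam simply skips them. An equivalent, slightly leaner
    # alternative (a sketch) is to hand the optimizer only trainable parameters:
    #
    #   optimizer = torch.optim.Adam(
    #       (p for p in model.parameters() if p.requires_grad),
    #       lr=lr, weight_decay=wd)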
    # Declare storage for hooked activations and gradients
    activations = [None] * 4
    debug_activations = [None] * 4
    debug_grads = [None] * 4
    debug_grads_old = [None] * 4  # initialised so the nan report works on batch 1

    # Declare hook functions; the leading index pins each hook to its list slot
    def styleHook(i, module, input, output):
        activations[i] = output

    def debugHook(i, module, input, output):
        debug_activations[i] = output

    # Establish hooks in the vgg encoder and the decoder
    for i, layer in enumerate(style_layers):
        model.vgg._modules[layer].register_forward_hook(partial(styleHook, i))
    for i, layer in enumerate(debug_layers):
        model.dec._modules[str(layer)].register_forward_hook(partial(debugHook, i))
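    # partial(styleHook, i) freezes the slot index at registration time, so each
    # hook writes its layer's output into its own slot on every forward pass.
    # An equivalent closure-based sketch:
    #
    #   def make_hook(store, i):
    #       def hook(module, input, output):
    #           store[i] = output
    #       return hook
    #   model.vgg._modules[layer].register_forward_hook(make_hook(activations, i))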
    for epoch in range(epochs):
        i = 0
        running_loss = 0
        for content_batch, style_batch in zip(content_images, style_images):
            i += 1
            optimizer.zero_grad()
            output = model(content_batch, style_batch)
            # model.t is the AdaIN target used as the content target (Eq. 12);
            # copy the hooked style activations before the next vgg pass
            # overwrites them with the output image's activations
            content_emb = model.t
            style_activations = deepcopy(activations)
            output_emb = model.vgg(output)
            output_activations = activations
            loss = criterion(content_emb, output_emb, style_activations, output_activations)
            if torch.isnan(loss).item():
                print("Got nan, here's the activations:")
                print(debug_activations)
                print(debug_grads_old)
                quit()
            loss.backward()
            debug_grads_old = deepcopy(debug_grads)
            debug_grads = [model.dec[layer].weight.grad for layer in debug_layers]
            optimizer.step()
            # Exponentially weighted moving average of the batch loss
            running_loss = alpha * loss.item() + (1 - alpha) * running_loss
            print(running_loss)
            save_image(output, f"./tmp/{epoch}_{i}_o.png", nrow=2)
            save_image(torch.sigmoid(content_batch), f"./tmp/{epoch}_{i}_c.png", nrow=2)
            save_image(torch.sigmoid(style_batch), f"./tmp/{epoch}_{i}_s.png", nrow=2)
        print(f'Epoch: [{epoch}/{epochs}], Loss: {loss.item()}')
    torch.save(model, 'adain_model')
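    # torch.save(model, ...) pickles the whole module by reference to its class;
    # the more portable convention (a sketch) is to save just the weights:
    #
    #   torch.save(model.state_dict(), 'adain_model.pth')
    #   # later: model = AdaINStyle(); model.load_state_dict(torch.load('adain_model.pth'))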