-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmeta_auto_maml.py
226 lines (167 loc) · 7.52 KB
/
meta_auto_maml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import time
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from learner import Network
import utils.utils as utils
from darts_architect import Architect
from copy import deepcopy
import pdb
class Meta(nn.Module):
    """
    Meta learner: trains the weights of ``self.model`` with first-order MAML.

    Each task adapts the model on its support set with SGD
    (``inner_optimizer_w``). The query-set gradient at the adapted weights is
    averaged over tasks and applied with Adam (``meta_optimizer_w``).
    """

    def __init__(self, args, criterion):
        """
        :param args: config namespace; reads update_lr_theta, meta_lr_theta,
            update_lr_w, meta_lr_w, n_way, k_spt, k_qry, meta_batch_size,
            update_step, update_step_test, init_channels and layers
        :param criterion: loss callable used as ``criterion(logits, labels)``
        """
        super(Meta, self).__init__()
        self.update_lr_theta = args.update_lr_theta
        self.meta_lr_theta = args.meta_lr_theta
        self.update_lr_w = args.update_lr_w
        self.meta_lr_w = args.meta_lr_w
        self.n_way = args.n_way
        self.k_spt = args.k_spt
        self.k_qry = args.k_qry
        self.meta_batch_size = args.meta_batch_size
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test
        self.criterion = criterion
        # NOTE(review): the model is unconditionally moved to CUDA.
        self.model = Network(args, args.init_channels, args.n_way, args.layers, criterion).cuda()
        # Outer (meta) and inner (per-task) optimizers over the same parameters.
        self.meta_optimizer_w = torch.optim.Adam(self.model.parameters(), lr=self.meta_lr_w)
        self.inner_optimizer_w = torch.optim.SGD(self.model.parameters(), lr=self.update_lr_w)

    def clip_grad_by_norm_(self, grad, max_norm):
        """
        In-place gradient clipping by global L2 norm.

        :param grad: iterable of gradient tensors; they are scaled in place
            when their combined L2 norm exceeds ``max_norm``
        :param max_norm: maximum allowable global norm
        :return: the pre-clipping global norm divided by the number of
            tensors, or 0.0 when ``grad`` is empty
        """
        # Materialize once so a generator is not exhausted by the norm pass
        # before the clipping pass.
        grad = list(grad)
        if not grad:
            return 0.0
        total_norm = sum(g.data.norm(2).item() ** 2 for g in grad) ** 0.5
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for g in grad:
                g.data.mul_(clip_coef)
        return total_norm / len(grad)

    @staticmethod
    def _count_correct(logits, labels):
        """Return how many rows of ``logits`` have their argmax equal to ``labels``."""
        # argmax of the raw logits equals argmax of their softmax.
        return torch.eq(logits.argmax(dim=1), labels).sum().item()

    def _inner_step(self, model, optimizer, x, y):
        """Take one ``optimizer`` step on ``self.criterion(model(x), y)``."""
        logits = model(x, alphas=model.arch_parameters())
        loss = self.criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def _update_w(self, x_spt, y_spt, x_qry, y_qry):
        """
        One first-order MAML update of the model weights over a task batch.

        For each task the model takes ``self.update_step`` SGD steps on the
        support set. The gradient of the query loss at the adapted weights is
        accumulated, and the weights are then restored. Finally the averaged
        query gradient is applied with ``self.meta_optimizer_w``.

        :param x_spt: support inputs, unpacked as [b, setsz, c_, h, w]
        :param y_spt: support labels, indexed per task as ``y_spt[i]``
        :param x_qry: query inputs; ``x_qry.shape[1]`` is the query size
        :param y_qry: query labels, indexed per task as ``y_qry[i]``
        :return: numpy array of length ``update_step + 1``: query accuracy
            before adaptation, then after each inner step, averaged over tasks
        :raises ValueError: if ``self.update_step < 1`` (no adapted query loss)
        """
        if self.update_step < 1:
            raise ValueError("update_step must be >= 1, got {}".format(self.update_step))
        meta_batch_size, _setsz, _c, _h, _w = x_spt.shape
        query_size = x_qry.shape[1]
        corrects = [0] * (self.update_step + 1)
        params = list(self.model.parameters())

        # Snapshot of the meta weights: every task starts from this point.
        w_clone = {name: p.detach().clone() for name, p in self.model.named_parameters()}
        # Running sum of first-order query gradients across tasks.
        grad_sum = [torch.zeros_like(p) for p in params]

        for i in range(meta_batch_size):
            x_q, y_q = x_qry[i], y_qry[i]

            # Query accuracy before any adaptation.
            with torch.no_grad():
                logits_q = self.model(x_q, alphas=self.model.arch_parameters())
            corrects[0] += self._count_correct(logits_q, y_q)

            for k in range(self.update_step):
                self._inner_step(self.model, self.inner_optimizer_w, x_spt[i], y_spt[i])
                # Only the final query forward needs a graph: its loss gives
                # the first-order meta-gradient for this task.
                with torch.set_grad_enabled(k == self.update_step - 1):
                    logits_q = self.model(x_q, alphas=self.model.arch_parameters())
                corrects[k + 1] += self._count_correct(logits_q, y_q)

            loss_q = self.criterion(logits_q, y_q)
            self.inner_optimizer_w.zero_grad()
            loss_q.backward()
            for acc, p in zip(grad_sum, params):
                # Parameters that do not influence loss_q keep grad None.
                if p.grad is not None:
                    acc.add_(p.grad)

            # Restore the meta weights before the next task.
            with torch.no_grad():
                for name, p in self.model.named_parameters():
                    p.copy_(w_clone[name])

        self.meta_optimizer_w.zero_grad()
        for acc, p in zip(grad_sum, params):
            # Assign instead of ``p.grad.copy_``: zero_grad() may have set
            # p.grad to None (the default since PyTorch 2.0).
            p.grad = acc / meta_batch_size
        self.meta_optimizer_w.step()

        return np.array(corrects) / (query_size * meta_batch_size)

    def forward(self, x_spt, y_spt, x_qry, y_qry, update_w_time):
        """
        Run one timed meta-training step on a batch of tasks.

        :param x_spt: [b, setsz, c_, h, w]
        :param y_spt: [b, setsz]
        :param x_qry: [b, query_size, c_, h, w]
        :param y_qry: [b, query_size]
        :param update_w_time: meter whose ``update`` receives the elapsed
            seconds (presumably an AverageMeter; confirm against the caller)
        :return: (query accuracies from ``_update_w``, update_w_time)
        """
        start = time.time()
        accs_w = self._update_w(x_spt, y_spt, x_qry, y_qry)
        update_w_time.update(time.time() - start)
        return accs_w, update_w_time

    def _update_w_finetunning(self, model, inner_optimizer_w, x_spt, y_spt, x_qry, y_qry):
        """
        Adapt ``model`` in place to one task and track query accuracy.

        :param model: network to adapt (modified in place)
        :param inner_optimizer_w: optimizer over ``model``'s parameters
        :param x_spt: support inputs, must be 4-D [setsz, c_, h, w]
        :param y_spt: support labels [setsz]
        :param x_qry: query inputs [query_size, c_, h, w]
        :param y_qry: query labels [query_size]
        :return: numpy array of length ``update_step_test + 1``: query
            accuracy before adaptation, then after each step
        :raises ValueError: if ``x_spt`` is not 4-D
        """
        if len(x_spt.shape) != 4:
            raise ValueError(
                "x_spt must be 4-D [setsz, c_, h, w], got shape {}".format(tuple(x_spt.shape)))
        query_size = x_qry.shape[0]
        corrects = [0] * (self.update_step_test + 1)

        with torch.no_grad():
            logits_q = model(x_qry, alphas=model.arch_parameters())
        corrects[0] += self._count_correct(logits_q, y_qry)

        for k in range(self.update_step_test):
            self._inner_step(model, inner_optimizer_w, x_spt, y_spt)
            with torch.no_grad():
                logits_q = model(x_qry, alphas=model.arch_parameters())
            corrects[k + 1] += self._count_correct(logits_q, y_qry)

        return np.array(corrects) / query_size

    def finetunning(self, x_spt, y_spt, x_qry, y_qry, update_w_time):
        """
        Measure fast adaptation on one task without modifying ``self.model``.

        A deep copy of the model is adapted with a fresh SGD optimizer
        (lr = ``self.update_lr_w``) and is discarded afterwards.

        :param x_spt: [setsz, c_, h, w]
        :param y_spt: [setsz]
        :param x_qry: [query_size, c_, h, w]
        :param y_qry: [query_size]
        :param update_w_time: meter whose ``update`` receives the elapsed seconds
        :return: (query accuracies per adaptation step, update_w_time)
        """
        model = deepcopy(self.model)
        inner_optimizer_w = torch.optim.SGD(model.parameters(), lr=self.update_lr_w)
        start = time.time()
        accs_w_finetunning = self._update_w_finetunning(model, inner_optimizer_w, x_spt, y_spt, x_qry, y_qry)
        update_w_time.update(time.time() - start)
        del model
        return accs_w_finetunning, update_w_time
def main():
    """Placeholder command-line entry point; it intentionally does nothing."""
    return None


if __name__ == '__main__':
    main()