-
Notifications
You must be signed in to change notification settings - Fork 3
/
dent.py
168 lines (139 loc) · 5.65 KB
/
dent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from copy import deepcopy
from sys import modules
import math
import torch
import torch.nn as nn
import torch.jit
from adamod import AdaMod
class Dent(nn.Module):
    """Dent adapts a model by entropy minimization during testing.

    Once dented, a model adapts itself by updating on every forward.

    Args:
        model: network to adapt. Every ``nn.BatchNorm2d`` inside it is
            replaced in place by ``convert_batchnorm`` with the class whose
            name is given by ``opt_cfg.BN_FUN``.
        opt_cfg: config object. ``STEPS`` is the number of adaptation steps
            per forward. ``BN_FUN`` and ``LOSS`` are names looked up in this
            module's ``globals()``. The optimizer settings are consumed by
            ``setup_optimizer``.

    Raises:
        ValueError: if ``opt_cfg.STEPS`` is not positive. The check runs
            before the model is touched.
    """

    def __init__(self, model, opt_cfg):
        super().__init__()
        self.opt_cfg = opt_cfg
        self.steps = opt_cfg.STEPS
        # Validate before convert_batchnorm mutates the caller's model.
        # Use an explicit raise so the check also survives `python -O`.
        if self.steps <= 0:
            raise ValueError("dent requires >= 1 step(s) to forward and update")
        self.model = convert_batchnorm(model, globals()[opt_cfg.BN_FUN])
        self.criterion = globals()[opt_cfg.LOSS]

    def forward(self, x):
        """Adapt on batch ``x`` for ``self.steps`` steps, then predict in eval mode."""
        # configure_batchnorm rebuilds the per-sample BN affine parameters from
        # the ckpt_* buffers on every call, so a fresh optimizer is needed.
        model = configure_batchnorm(x, self.model)
        optimizer = setup_optimizer(collect_params(model), self.opt_cfg)
        for _ in range(self.steps):
            forward_and_adapt(x, model, optimizer, self.criterion)
        model.eval()
        return model(x)
class SampleAwareStaticBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that always normalizes with the stored running statistics.

    ``weight``/``bias`` may be shared, with shape ``(C,)``, or per-sample, with
    shape ``(N, C)`` where N matches the batch size of ``x``. They are folded
    with the running statistics into one scale and shift per (sample, channel).
    ``self.training`` is never consulted, and the running statistics are not
    updated. The input is expected to be NCHW, because scale/shift are
    broadcast as ``(·, C, 1, 1)``.
    """

    def forward(self, x):
        inv_std = torch.rsqrt(self.running_var + self.eps).reshape(1, -1)
        scale = self.weight * inv_std
        shift = self.bias - self.running_mean.reshape(1, -1) * scale
        # Add trailing spatial axes so scale/shift broadcast over H and W.
        return x * scale[:, :, None, None] + shift[:, :, None, None]
class SampleAwareOnlineBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that always normalizes with the current batch's statistics.

    Mean and biased variance are taken over dims (0, 2, 3) of the NCHW input,
    in both train and eval mode. The running statistics are neither read nor
    updated. ``weight``/``bias`` may be shared ``(C,)`` or per-sample
    ``(N, C)`` tensors.
    """

    def forward(self, x):
        reduce_dims = [0, 2, 3]
        batch_mean = x.mean(reduce_dims)
        batch_var = x.var(reduce_dims, unbiased=False)
        scale = self.weight * torch.rsqrt(batch_var + self.eps).reshape(1, -1)
        shift = self.bias - batch_mean.reshape(1, -1) * scale
        # Add trailing spatial axes so scale/shift broadcast over H and W.
        return x * scale[:, :, None, None] + shift[:, :, None, None]
@torch.jit.script
def tent(x: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy (in nats) of the softmax over dim 1 of logits x.

    The per-row entropy is averaged over dim 0.
    """
    probs = torch.softmax(x, dim=1)
    log_probs = torch.log_softmax(x, dim=1)
    per_row = (probs * log_probs).sum(dim=1)
    return -per_row.mean(dim=0)
@torch.jit.script
def shot(x: torch.Tensor) -> torch.Tensor:
    """Entropy-plus-diversity loss from logits x (softmax over dim 1).

    Returns the mean per-row entropy plus sum(p_bar * log p_bar), where p_bar
    is the batch-mean prediction. The second term is the negative entropy of
    p_bar, so minimizing the total favors confident yet diverse predictions.
    Both logs add 1e-5 to avoid log(0).
    """
    probs = x.softmax(1)
    entropy = -(probs * torch.log(probs + 1e-5)).sum(dim=1).mean(0)
    mean_probs = probs.mean(0)
    diversity = (mean_probs * torch.log(mean_probs + 1e-5)).sum()
    return entropy + diversity
@torch.enable_grad()  # ensure grads in possible no grad context for testing
def forward_and_adapt(x, model, optimizer, criterion):
    """Forward and adapt model on batch of data.

    Runs ``model(x)``, scores the outputs with ``criterion`` (e.g. prediction
    entropy), back-propagates, and applies one optimizer step.

    Args:
        x: input batch passed straight to ``model``.
        model: module whose adapted parameters are held by ``optimizer``.
        optimizer: optimizer stepping those parameters.
        criterion: callable mapping model outputs to a scalar loss.
    """
    outputs = model(x)
    loss = criterion(outputs)
    optimizer.zero_grad()
    # Every call builds a fresh graph, so none is ever reused. Letting backward
    # free it (no retain_graph) lowers peak memory during optimizer.step().
    loss.backward()
    optimizer.step()
def setup_optimizer(params, config):
    """Set up optimizer for dent adaptation.

    Dent needs an optimizer for test-time entropy minimization.
    In principle, dent could make use of any gradient optimizer.
    In practice, we advise choosing AdaMod.
    For optimization settings, we advise to use the settings from the end of
    training, if known, or start with a low learning rate (like 0.001) if not.
    For best results, try tuning the learning rate and batch size.

    Args:
        params: iterable of parameters to optimize.
        config: object exposing ``METHOD`` plus the hyper-parameters that
            method reads: ``LR`` and ``WD`` always. AdaMod also reads ``BETA``
            and ``BETA3``, Adam reads ``BETA``, and SGD reads ``MOMENTUM``,
            ``DAMPENING`` and ``NESTEROV``.

    Returns:
        The configured optimizer.

    Raises:
        NotImplementedError: if ``config.METHOD`` is not 'AdaMod', 'Adam' or
            'SGD'.
    """
    method = config.METHOD
    if method == 'AdaMod':
        return AdaMod(params,
                      lr=config.LR,
                      betas=(config.BETA, 0.999),
                      beta3=config.BETA3,
                      weight_decay=config.WD)
    if method == 'Adam':
        return torch.optim.Adam(params,
                                lr=config.LR,
                                betas=(config.BETA, 0.999),
                                weight_decay=config.WD)
    if method == 'SGD':
        return torch.optim.SGD(params,
                               lr=config.LR,
                               momentum=config.MOMENTUM,
                               dampening=config.DAMPENING,
                               weight_decay=config.WD,
                               nesterov=config.NESTEROV)
    raise NotImplementedError(
        f"unsupported optimizer method {method!r}; "
        "expected one of 'AdaMod', 'Adam', 'SGD'")
def copy_model_state(model):
    """Return a deep copy of ``model.state_dict()``.

    The copy is independent of later parameter/buffer updates, so it can be
    restored with ``load_model_state`` after adaptation.
    """
    return deepcopy(model.state_dict())
def load_model_state(model, model_state):
    """Load ``model_state`` into ``model`` in strict mode (keys must match exactly).

    Returns None. Key mismatches raise from ``load_state_dict``.
    """
    model.load_state_dict(state_dict=model_state, strict=True)
def collect_params(model):
    """Return, in ``model.parameters()`` order, the parameters with ``requires_grad`` set.

    These are the parameters dent hands to its optimizer.
    """
    return [param for param in model.parameters() if param.requires_grad]
def configure_batchnorm(x, model):
    """Configure model for use with dent on batch ``x``.

    The model is put in train mode and every parameter is frozen. Then each
    ``nn.BatchNorm2d`` gets fresh trainable ``weight``/``bias`` Parameters of
    shape ``(batch_size, num_features)``, tiled from its ``ckpt_weight`` /
    ``ckpt_bias`` buffers. Those buffers must exist (``convert_batchnorm``
    registers them). Only ``x.size(0)`` is read from ``x``. The model is
    modified in place and returned.
    """
    batch_size = x.size(0)
    # Train mode, because dent optimizes the model to minimize entropy.
    model.train()
    # Freeze everything, then re-enable only what dent updates.
    model.requires_grad_(False)
    norm_layers = [mod for mod in model.modules() if isinstance(mod, nn.BatchNorm2d)]
    for norm in norm_layers:
        # One affine row per sample, restarted from the checkpointed values.
        norm.weight = nn.Parameter(norm.ckpt_weight.unsqueeze(0).repeat(batch_size, 1))
        norm.bias = nn.Parameter(norm.ckpt_bias.unsqueeze(0).repeat(batch_size, 1))
        norm.requires_grad_(True)
    return model
def convert_batchnorm(module, bn_fun):
    """Recursively replace every ``nn.BatchNorm2d`` with ``bn_fun(num_features, eps)``.

    Each replacement shares, rather than copies, the original ``weight``,
    ``bias``, ``running_mean`` and ``running_var`` tensors. It also registers
    the original weight/bias objects as the buffers ``ckpt_weight`` /
    ``ckpt_bias``, which ``configure_batchnorm`` reads. Non-BN modules are
    updated in place and returned as-is. A BN passed as the root is returned
    as a new module.

    NOTE(review): only ``num_features`` and ``eps`` reach ``bn_fun``. The
    momentum, affine and track_running_stats settings are not forwarded.
    """
    if isinstance(module, nn.BatchNorm2d):
        converted = bn_fun(module.num_features, module.eps)
        with torch.no_grad():
            converted.weight = module.weight
            converted.bias = module.bias
            converted.register_buffer("ckpt_weight", module.weight)
            converted.register_buffer("ckpt_bias", module.bias)
            converted.running_mean = module.running_mean
            converted.running_var = module.running_var
    else:
        converted = module
    for child_name, child in module.named_children():
        converted.add_module(child_name, convert_batchnorm(child, bn_fun))
    return converted