-
Notifications
You must be signed in to change notification settings - Fork 47
/
norm.py
65 lines (53 loc) · 2.11 KB
/
norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from copy import deepcopy
import torch
import torch.nn as nn
class Norm(nn.Module):
    """Norm adapts a model by estimating feature statistics during testing.

    Once equipped with Norm, the model normalizes its features during testing
    with batch-wise statistics, just like batch norm does during training.
    """

    def __init__(self, model, eps=1e-5, momentum=0.1,
                 reset_stats=False, no_stats=False):
        """Wrap and configure `model` for test-time normalization.

        Args:
            model: the model to adapt (its BatchNorm2d layers are configured).
            eps: epsilon assigned to every BatchNorm2d for stability.
            momentum: momentum assigned to every BatchNorm2d for stat updates.
            reset_stats: if True, reset running stats to drop train-time stats.
            no_stats: if True, disable running stats and use batch stats only.
        """
        super().__init__()
        # NOTE: the original assigned the unconfigured model first and then
        # immediately overwrote it; the dead assignment is removed here.
        self.model = configure_model(model, eps, momentum, reset_stats,
                                     no_stats)
        # Snapshot of the configured state so reset() can restore it later.
        self.model_state = deepcopy(self.model.state_dict())

    def forward(self, x):
        """Run the adapted model on input x."""
        return self.model(x)

    def reset(self):
        """Restore the model to its state right after configuration."""
        self.model.load_state_dict(self.model_state, strict=True)
def collect_stats(model):
    """Collect the normalization stats from batch norms.

    Walk the model's modules and collect all batch normalization stats.
    Return the stats and their names.
    """
    stats, names = [], []
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        state = module.state_dict()
        if module.affine:
            # Drop the learned affine parameters; only the stats are wanted.
            del state['weight'], state['bias']
        stats.extend(state.values())
        names.extend(f"{module_name}.{key}" for key in state)
    return stats, names
def configure_model(model, eps, momentum, reset_stats, no_stats):
    """Configure model for adaptation by test-time normalization."""
    batch_norms = (m for m in model.modules()
                   if isinstance(m, nn.BatchNorm2d))
    for bn in batch_norms:
        # Train mode makes the forward pass use batch-wise statistics.
        bn.train()
        # Epsilon for numerical stability, momentum for stat updates.
        bn.eps = eps
        bn.momentum = momentum
        if reset_stats:
            # Forget the train-time stats so test stats start fresh.
            bn.reset_running_stats()
        if no_stats:
            # Drop running stats entirely; normalize with batch stats only.
            bn.track_running_stats = False
            bn.running_mean = None
            bn.running_var = None
    return model