-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbn_fold.py
67 lines (50 loc) · 1.9 KB
/
bn_fold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
def fuse_single_conv_bn_pair(block1, block2):
    """Fold a BatchNorm2d into the Conv2d that precedes it.

    For y = BN(conv(x)) with frozen running statistics, the affine BN
    transform is absorbed into the conv's weight and bias:

        scale = gamma / sqrt(running_var + eps)
        W'    = W * scale          (per output channel)
        b'    = b * scale + (beta - running_mean * scale)

    Args:
        block1: candidate ``nn.BatchNorm2d`` — the module that FOLLOWS the conv
            (callers pass the current child here and the previous child as
            ``block2``).
        block2: candidate ``nn.Conv2d`` — the module that PRECEDES the bn.

    Returns:
        The conv module (mutated in place with BN folded in) when the pair
        matches, otherwise ``False`` so callers can test truthiness.

    Note:
        Only valid in eval mode — folding uses the running statistics, so the
        result matches ``bn(conv(x))`` with both modules in ``.eval()``.
    """
    if not (isinstance(block1, nn.BatchNorm2d) and isinstance(block2, nn.Conv2d)):
        return False
    bn, conv = block1, block2

    with torch.no_grad():
        mu = bn.running_mean
        var = bn.running_var
        # affine=False BN has no learnable scale/shift: act as gamma=1, beta=0.
        # (The original code KeyError'd on bn.state_dict()['weight'] here.)
        if bn.affine:
            gamma = bn.weight
            beta = bn.bias
        else:
            gamma = torch.ones_like(mu)
            beta = torch.zeros_like(mu)

        scale = gamma / torch.sqrt(var + bn.eps)   # per-output-channel scale
        shift = beta - mu * scale                  # per-output-channel shift

        # Broadcast scale over (out_channels, in/groups, kH, kW).
        conv.weight.mul_(scale.reshape(-1, 1, 1, 1))
        if conv.bias is None:
            conv.bias = nn.Parameter(shift.clone())
        else:
            conv.bias.mul_(scale).add_(shift)

    return conv
def fuse_bn_recursively(model):
    """Fold every Conv2d -> BatchNorm2d pair found in *model*, in place.

    Walks the direct children in registration order, pairing each child with
    its immediate predecessor; when ``fuse_single_conv_bn_pair`` recognizes a
    (bn, conv) pair, the conv absorbs the norm and the norm's slot is filled
    with ``nn.Identity`` so the module graph keeps its shape. Containers are
    descended into recursively.

    Args:
        model: any ``nn.Module``; it is mutated in place.

    Returns:
        The same (mutated) model, for call-chaining convenience.
    """
    child_names = list(model._modules)
    for idx, name in enumerate(child_names):
        # The first child has no predecessor; pairing it with itself is a
        # guaranteed non-match, matching the original initialization trick.
        prev = child_names[idx - 1] if idx > 0 else name
        fused = fuse_single_conv_bn_pair(model._modules[name],
                                         model._modules[prev])
        if fused:
            model._modules[prev] = fused
            model._modules[name] = nn.Identity()
        # Recurse into containers (an Identity placed above has no children).
        if len(model._modules[name]._modules) > 0:
            fuse_bn_recursively(model._modules[name])
    return model