# initialization.py -- PyTorch weight-initialization helpers.
# Forked from dilligencer-zrj/code_zoo.
def weights_init_xavier(m):
    """Initialise a single module in place; meant for ``net.apply(...)``.

    Dispatch is by class-name substring, checked in this order:

    * name contains ``'Conv'`` or ``'Linear'``: weight gets Xavier-normal
      init with gain 1 (bias is left untouched);
    * name contains ``'BatchNorm2d'``: weight ~ N(mean=1.0, std=0.02),
      bias set to 0.

    Modules whose name matches but that carry no weight tensor (e.g. a
    container class whose name contains "Conv", or
    ``BatchNorm2d(affine=False)``) are skipped instead of raising.

    Args:
        m: the ``torch.nn.Module`` currently visited by ``Module.apply``.
    """
    # Local import: the file's module-level imports only provide ``nn``.
    from torch.nn import init

    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    if 'Conv' in classname or 'Linear' in classname:
        init.xavier_normal_(weight, gain=1)
    elif 'BatchNorm2d' in classname:
        # (1.0, 0.02) is a mean/std pair; feeding it to ``uniform`` as
        # (low, high) gives an inverted, invalid range.
        init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            init.constant_(bias, 0.0)
# Usage (visits every submodule of ``net``):
#   net.apply(weights_init_xavier)
from torch import nn
class Initializer:
    """Apply a ``torch.nn.init``-style function to the layers of a model.

    Only ``nn.Conv2d`` / ``nn.Linear`` (via the supplied initialiser) and
    ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` (weight = 1, bias = 0) are
    modified; every other module is left unchanged.
    """

    @staticmethod
    def initialize(model, initialization, **kwargs):
        """Initialise ``model``'s layers in place.

        Args:
            model: module whose submodules are visited with ``Module.apply``.
            initialization: callable ``f(tensor, **kwargs)`` that fills
                ``tensor`` in place, e.g. ``torch.nn.init.xavier_uniform_``.
            **kwargs: forwarded to ``initialization`` for weights only.

        Conv2d/Linear biases (when present) are passed to ``initialization``
        WITHOUT ``kwargs``. If that call raises ``ValueError`` (e.g. fan-based
        initialisers on a 1-D tensor) or ``TypeError`` (initialiser needs
        arguments that were not supplied), the bias keeps its current value.
        """
        def weights_init(m):
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                initialization(m.weight.data, **kwargs)
                if m.bias is not None:
                    try:
                        initialization(m.bias.data)
                    except (ValueError, TypeError):
                        # Best effort: this initialiser does not apply to a
                        # bias vector, so keep PyTorch's default bias.
                        pass
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # With affine=False the norm layer has no weight/bias.
                if m.weight is not None:
                    m.weight.data.fill_(1.0)
                if m.bias is not None:
                    m.bias.data.fill_(0)

        model.apply(weights_init)
# Usage:
#   from torch.nn import init
#   net = Model()  # instantiate the model
#   # Xavier-uniform weights, gain tuned for ReLU:
#   Initializer.initialize(model=net, initialization=init.xavier_uniform_, gain=init.calculate_gain('relu'))
#   # or a normal distribution:
#   Initializer.initialize(model=net, initialization=init.normal_, mean=0, std=0.2)