-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsqueeze_and_excitation.py
executable file
·144 lines (111 loc) · 4.83 KB
/
squeeze_and_excitation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Squeeze and Excitation Module
*****************************
Collection of squeeze and excitation classes, each of which can be inserted as a block into a neural network architecture
1. `Channel Squeeze and Excitation <https://arxiv.org/abs/1709.01507>`_
2. `Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
3. `Channel and Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
"""
from enum import Enum
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChannelSELayer(nn.Module):
    """
    Re-implementation of Squeeze-and-Excitation (SE) block described in:
        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*

    Every channel is averaged over its spatial positions, the resulting per-channel
    vector is passed through a two-layer bottleneck (Linear -> ReLU -> Linear -> Sigmoid),
    and the input is rescaled channel-wise by the resulting gates.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: No of input channels
        :param reduction_ratio: Factor by which num_channels is divided to obtain the bottleneck width
        """
        super(ChannelSELayer, self).__init__()
        # Keep at least one hidden unit: when num_channels < reduction_ratio the integer
        # division yields 0, which would build an empty bottleneck whose gate no longer
        # depends on the input at all.
        num_channels_reduced = max(1, num_channels // reduction_ratio)
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W); any number of trailing
                             spatial dimensions is accepted, e.g. (batch_size, num_channels, D, H, W)
        :return: output tensor with the same shape as input_tensor
        """
        batch_size, num_channels = input_tensor.shape[:2]

        # Average along each channel. reshape (rather than view) also handles inputs whose
        # spatial dimensions cannot be flattened without a copy, e.g. transposed tensors.
        squeeze_tensor = input_tensor.reshape(batch_size, num_channels, -1).mean(dim=2)

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        # One singleton axis per spatial dimension so the gates broadcast over the input.
        gate_shape = (batch_size, num_channels) + (1,) * (input_tensor.dim() - 2)
        output_tensor = torch.mul(input_tensor, fc_out_2.view(gate_shape))
        return output_tensor
class SpatialSELayer(nn.Module):
    """
    Spatial SE block: squeezes along the channel axis and excites every spatial position, as described in:
        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
    """

    def __init__(self, num_channels):
        """
        :param num_channels: No of input channels
        """
        super().__init__()
        # 1x1 convolution collapsing all channels into a single attention map
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor, weights=None):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :param weights: optional weights for few shot learning; when given they are averaged over
                        dim 0, reshaped into a 1x1 kernel and applied (without bias) in place of
                        the layer's own convolution
        :return: output_tensor, input_tensor scaled position-wise by the attention map
        """
        n_batch, n_channels, height, width = input_tensor.size()

        # spatial squeeze
        if weights is None:
            logits = self.conv(input_tensor)
        else:
            kernel = weights.mean(dim=0).view(1, n_channels, 1, 1)
            logits = F.conv2d(input_tensor, kernel)

        # spatial excitation: one gate per position, broadcast across channels
        attention_map = self.sigmoid(logits).view(n_batch, 1, height, width)
        return input_tensor * attention_map
class ChannelSpatialSELayer(nn.Module):
    """
    Concurrent spatial and channel squeeze & excitation, re-implemented from:
        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018, arXiv:1803.02579*

    The input is passed through a channel SE branch and a spatial SE branch, and the two
    results are combined with an element-wise maximum.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: No of input channels
        :param reduction_ratio: Reduction factor forwarded to the channel SE branch
        """
        super().__init__()
        self.cSE = ChannelSELayer(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer(num_channels)

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: output_tensor, element-wise maximum of the two branch outputs
        """
        channel_recalibrated = self.cSE(input_tensor)
        spatial_recalibrated = self.sSE(input_tensor)
        return torch.max(channel_recalibrated, spatial_recalibrated)
class SELayer(Enum):
    """
    Enum restricting the type of SE blocks available, so that type checking can be added when inserting these
    blocks into a neural network::

        if self.se_block_type == se.SELayer.CSE.value:
            self.SELayer = se.ChannelSELayer(params['num_filters'])
        elif self.se_block_type == se.SELayer.SSE.value:
            self.SELayer = se.SpatialSELayer(params['num_filters'])
        elif self.se_block_type == se.SELayer.CSSE.value:
            self.SELayer = se.ChannelSpatialSELayer(params['num_filters'])
    """

    # No SE block
    NONE = 'NONE'
    # Channel SE (ChannelSELayer)
    CSE = 'CSE'
    # Spatial SE (SpatialSELayer)
    SSE = 'SSE'
    # Concurrent channel and spatial SE (ChannelSpatialSELayer)
    CSSE = 'CSSE'