[CVPR-18] SENets

Architecture

Implementation


import torch
from torch import nn

class SENets(nn.Module):
    def __init__(self, in_ch, reduction_ratio):
        """
        :param in_ch: number of input channels of the feature map
        :param reduction_ratio: reduction ratio r for the channel dimension
        """
        super(SENets, self).__init__()
        # [N,C,H,W] -> [N,C,1,1]
        self.Global_pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            # [N,C] -> [N,C/r]
            nn.Linear(in_features=in_ch, out_features=in_ch // reduction_ratio),
            nn.ReLU(),
            # [N,C/r] -> [N,C]
            nn.Linear(in_features=in_ch // reduction_ratio, out_features=in_ch),
            # squash the channel weights into (0, 1), as in the paper's F_ex
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: input feature map [N,C,H,W]
        :return: recalibrated feature map [N,C,H,W]
        """
        N, C, H, W = x.shape
        # [N,C,H,W] -> [N,C,1,1] -> [N,C]
        out = self.Global_pooling(x).view(N, C)
        # [N,C] -> [N,C]
        out = self.fc(out)
        # [N,C] -> [N,C,1,1]
        out = out[:, :, None, None]
        # broadcast: [N,C,1,1] * [N,C,H,W] -> [N,C,H,W]
        out = x * out
        return out

if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)
    print("input_shape: {}".format(x.shape))
    model = SENets(in_ch=64, reduction_ratio=8)
    out = model(x)
    print("output_shape: {}".format(out.shape))
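In practice the SE block is attached after a convolutional stage and recalibrates its output. The sketch below wires the same squeeze-excite-scale steps behind a plain conv block; `ConvSE` is a hypothetical name for illustration, not from the paper:

```python
import torch
from torch import nn

# Hypothetical wrapper (not from the original post): a plain conv stage
# followed by SE-style channel recalibration.
class ConvSE(nn.Module):
    def __init__(self, in_ch, out_ch, reduction_ratio=8):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )
        # minimal SE block: squeeze -> excite -> scale
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(out_ch, out_ch // reduction_ratio),
            nn.ReLU(),
            nn.Linear(out_ch // reduction_ratio, out_ch),
            nn.Sigmoid(),
        )

    def forward(self, x):
        u = self.conv(x)                      # [N,out_ch,H,W]
        n, c, _, _ = u.shape
        s = self.fc(self.pool(u).view(n, c))  # [N,out_ch] channel weights
        return u * s[:, :, None, None]        # rescale each channel of u

if __name__ == '__main__':
    y = ConvSE(3, 64)(torch.randn(2, 3, 32, 32))
    print(y.shape)  # torch.Size([2, 64, 32, 32])
```

Because the block only rescales channels, the output shape always matches the conv stage's output shape.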

FAQ

Q: What does the red line in the model diagram mean? Is it a residual connection?

After passing through $F_{sq}$ and $F_{ex}$, the generated features form a [N,C] tensor of channel weights; $F_{scale}$ still needs the original input U to rescale it, so the red line carries U into the scaling step rather than forming a residual connection.
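The scaling step can be checked in isolation: the [N,C] weights are reshaped to [N,C,1,1] and broadcast-multiplied over U. A minimal sketch with made-up shapes:

```python
import torch

N, C, H, W = 2, 4, 3, 3
U = torch.randn(N, C, H, W)      # original input features U
s = torch.rand(N, C)             # channel weights from F_ex, in (0, 1)

out = U * s[:, :, None, None]    # F_scale: per-channel rescaling
print(out.shape)                 # torch.Size([2, 4, 3, 3])

# every channel of U is multiplied by a single scalar weight
assert torch.allclose(out[0, 1], U[0, 1] * s[0, 1])
```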

Q: Questions about the pooling layer nn.AdaptiveAvgPool2d

nn.AdaptiveAvgPool2d(1) can look opaque at first; it simply averages each channel over its entire spatial extent, i.e. [N,C,H,W] → [N,C,1,1].
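A quick sanity check of that behavior (shapes chosen arbitrarily):

```python
import torch
from torch import nn

x = torch.randn(2, 64, 7, 5)          # spatial size need not be square
pooled = nn.AdaptiveAvgPool2d(1)(x)   # [2,64,7,5] -> [2,64,1,1]
print(pooled.shape)                   # torch.Size([2, 64, 1, 1])

# with output size 1, it is just a mean over the spatial dimensions
assert torch.allclose(pooled, x.mean(dim=(2, 3), keepdim=True))
```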

Q: Questions about the fully connected layer nn.Linear

nn.Linear expects an input of shape [N,C], so the [N,C,1,1] tensor must be squeezed to [N,C] before entering the fully connected layers:

# [N,C,H,W] -> [N,C,1,1] -> [N,C]
out = self.Global_pooling(x).view(N,C)
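`view(N, C)` is one way to do this squeeze; `torch.flatten(x, 1)` (not used in the snippet above) produces the same result:

```python
import torch

x = torch.randn(2, 64, 1, 1)
a = x.view(2, 64)          # reshape as in the snippet above
b = torch.flatten(x, 1)    # flatten everything after the batch dim
print(a.shape, b.shape)    # torch.Size([2, 64]) torch.Size([2, 64])
assert torch.equal(a, b)
```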

Q: Some details from the paper

# [N,C,H,W] -> [N,C,1,1]
self.Global_pooling = nn.AdaptiveAvgPool2d(1)