【ECCV-18】CBAM

架构

输入特征图F(N,C,H,W)

经过CA模块得到 $F' = M_c(F) \otimes F$：通道注意力 $M_c(F)$ 的形状为 (N,C,1,1)，沿 H、W 方向广播后与 F(N,C,H,W) 逐元素相乘，得到 F'(N,C,H,W)

经过SA模块得到 $F'' = M_s(F') \otimes F'$：空间注意力 $M_s(F')$ 的形状为 (N,1,H,W)，沿通道方向广播后与 F'(N,C,H,W) 逐元素相乘，得到 F''(N,C,H,W)

最后在 ResBlock 中将 F'' 与恒等分支相加（残差连接）。注意：下文实现的 CBAM 模块本身只输出 F''，残差相加需要在外部的残差块中完成

Untitled

Untitled

Untitled

实现

CA模块

image.png

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class ChannelAttentionModule(nn.Module):
    """Channel attention map of CBAM.

    Squeezes the spatial dimensions with both global max pooling and global
    average pooling, passes each pooled descriptor through the same two-layer
    MLP (implemented with 1x1 convolutions), sums the two results and applies
    a sigmoid:

        M_c(F) = sigmoid(MLP(MaxPool(F)) + MLP(AvgPool(F)))

    Input ``[N, C, H, W]`` -> attention weights ``[N, C, 1, 1]`` in (0, 1).
    """

    def __init__(self, in_ch, reduction_ratio):
        """
        :param in_ch: number of channels of the input feature map
        :param reduction_ratio: factor by which the hidden MLP layer shrinks the
            channel count; the hidden width is clamped to at least 1 so a small
            ``in_ch`` cannot produce an empty layer
        """
        super(ChannelAttentionModule, self).__init__()
        hidden_ch = max(1, in_ch // reduction_ratio)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # The pooled descriptors are 1x1 spatially, so a 1x1 conv is exactly a
        # fully-connected layer; a padded 3x3 conv would only ever use its
        # centre tap and waste the other 8/9 of its weights.
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=hidden_ch, kernel_size=1),
            nn.ReLU(),
        )
        # No activation after the second layer: a ReLU here would force the
        # pre-sigmoid logits to be >= 0 and confine the attention weights to
        # [0.5, 1), so the module could never suppress a channel.
        self.fc2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_ch, out_channels=in_ch, kernel_size=1),
        )
        self.act = nn.Sigmoid()

    def _mlp(self, descriptor):
        """Apply the shared MLP to a pooled descriptor ([N,C,1,1] -> [N,C,1,1])."""
        return self.fc2(self.fc1(descriptor))

    def forward(self, x):
        # [N,C,H,W] -> [N,C,1,1]
        max_part = self._mlp(self.max_pool(x))
        # [N,C,H,W] -> [N,C,1,1]
        avg_part = self._mlp(self.avg_pool(x))
        # Both branches share the MLP weights; merge, then squash into (0, 1).
        return self.act(max_part + avg_part)
if __name__ == '__main__':
    sample = torch.rand(1, 64, 16, 16)
    num_channels = sample.shape[1]
    channel_attention = ChannelAttentionModule(in_ch=num_channels, reduction_ratio=8)
    # Expected output: torch.Size([1, 64, 1, 1])
    print(channel_attention(sample).shape)

SA模块

image.png

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class SpatialAttentionModule(nn.Module):
    """Spatial attention map of CBAM.

    Pools the input along the channel axis with mean and max, stacks the two
    single-channel maps and applies one convolution followed by a sigmoid:

        M_s(F) = sigmoid(conv([MeanPool_c(F); MaxPool_c(F)]))

    Input ``[N, C, H, W]`` -> attention weights ``[N, 1, H, W]`` in (0, 1).
    """

    def __init__(self, kernel_size=7):
        """
        :param kernel_size: odd size of the spatial convolution (default 7);
            padding of ``kernel_size // 2`` keeps H and W unchanged
        :raises ValueError: if ``kernel_size`` is not a positive odd integer
        """
        super(SpatialAttentionModule, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        # No ReLU before the sigmoid: it would clamp the logits to >= 0 and
        # restrict the attention weights to [0.5, 1).
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=2,
                out_channels=1,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
        )
        self.act = nn.Sigmoid()

    def forward(self, x):
        # [N,C,H,W] -> [N,1,H,W]
        avg = torch.mean(x, dim=1, keepdim=True)
        # With `dim` given, torch.max returns (values, indices); keep the values.
        max_val, _ = torch.max(x, dim=1, keepdim=True)
        # [N,1,H,W] + [N,1,H,W] -> [N,2,H,W]  (channel 0: mean, channel 1: max)
        out = torch.cat((avg, max_val), dim=1)
        return self.act(self.conv(out))
if __name__ == '__main__':
    demo_input = torch.rand(1, 64, 16, 16)
    spatial_attention = SpatialAttentionModule()
    # Expected output: torch.Size([1, 1, 16, 16])
    print(spatial_attention(demo_input).shape)

总体代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from torch import nn

class ChannelAttentionModule(nn.Module):
    """Channel attention map of CBAM.

    Squeezes the spatial dimensions with both global max pooling and global
    average pooling, passes each pooled descriptor through the same two-layer
    MLP (implemented with 1x1 convolutions), sums the two results and applies
    a sigmoid:

        M_c(F) = sigmoid(MLP(MaxPool(F)) + MLP(AvgPool(F)))

    Input ``[N, C, H, W]`` -> attention weights ``[N, C, 1, 1]`` in (0, 1).
    """

    def __init__(self, in_ch, reduction_ratio):
        """
        :param in_ch: number of channels of the input feature map
        :param reduction_ratio: factor by which the hidden MLP layer shrinks the
            channel count; the hidden width is clamped to at least 1 so a small
            ``in_ch`` cannot produce an empty layer
        """
        super(ChannelAttentionModule, self).__init__()
        hidden_ch = max(1, in_ch // reduction_ratio)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # The pooled descriptors are 1x1 spatially, so a 1x1 conv is exactly a
        # fully-connected layer; a padded 3x3 conv would only ever use its
        # centre tap and waste the other 8/9 of its weights.
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=hidden_ch, kernel_size=1),
            nn.ReLU(),
        )
        # No activation after the second layer: a ReLU here would force the
        # pre-sigmoid logits to be >= 0 and confine the attention weights to
        # [0.5, 1), so the module could never suppress a channel.
        self.fc2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_ch, out_channels=in_ch, kernel_size=1),
        )
        self.act = nn.Sigmoid()

    def _mlp(self, descriptor):
        """Apply the shared MLP to a pooled descriptor ([N,C,1,1] -> [N,C,1,1])."""
        return self.fc2(self.fc1(descriptor))

    def forward(self, x):
        # [N,C,H,W] -> [N,C,1,1]
        max_part = self._mlp(self.max_pool(x))
        # [N,C,H,W] -> [N,C,1,1]
        avg_part = self._mlp(self.avg_pool(x))
        # Both branches share the MLP weights; merge, then squash into (0, 1).
        return self.act(max_part + avg_part)

class SpatialAttentionModule(nn.Module):
    """Spatial attention map of CBAM.

    Pools the input along the channel axis with mean and max, stacks the two
    single-channel maps and applies one convolution followed by a sigmoid:

        M_s(F) = sigmoid(conv([MeanPool_c(F); MaxPool_c(F)]))

    Input ``[N, C, H, W]`` -> attention weights ``[N, 1, H, W]`` in (0, 1).
    """

    def __init__(self, kernel_size=7):
        """
        :param kernel_size: odd size of the spatial convolution (default 7);
            padding of ``kernel_size // 2`` keeps H and W unchanged
        :raises ValueError: if ``kernel_size`` is not a positive odd integer
        """
        super(SpatialAttentionModule, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        # No ReLU before the sigmoid: it would clamp the logits to >= 0 and
        # restrict the attention weights to [0.5, 1).
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=2,
                out_channels=1,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
        )
        self.act = nn.Sigmoid()

    def forward(self, x):
        # [N,C,H,W] -> [N,1,H,W]
        avg = torch.mean(x, dim=1, keepdim=True)
        # With `dim` given, torch.max returns (values, indices); keep the values.
        max_val, _ = torch.max(x, dim=1, keepdim=True)
        # [N,1,H,W] + [N,1,H,W] -> [N,2,H,W]  (channel 0: mean, channel 1: max)
        out = torch.cat((avg, max_val), dim=1)
        return self.act(self.conv(out))

class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    The input is rescaled by the channel attention map, and that result is
    rescaled by the spatial attention map. The output has the input's shape.
    No residual addition happens here; that belongs to the surrounding block.
    """

    def __init__(self, in_ch, reduction_ratio):
        super().__init__()
        self.ca = ChannelAttentionModule(in_ch=in_ch, reduction_ratio=reduction_ratio)
        self.sa = SpatialAttentionModule()

    def forward(self, x):
        # [N,C,H,W] * [N,C,1,1] (broadcast over H, W) -> [N,C,H,W]
        channel_refined = x * self.ca(x)
        # [N,C,H,W] * [N,1,H,W] (broadcast over C) -> [N,C,H,W]
        return channel_refined * self.sa(channel_refined)

if __name__ == '__main__':
    x = torch.rand(1, 64, 16, 16)
    C = x.shape[1]

    model = CBAM(in_ch=C, reduction_ratio=8)
    # CBAM keeps the input shape: torch.Size([1, 64, 16, 16])
    print(model(x).shape)

    # Shape checks for the two attention maps on their own.
    # SpatialAttentionModule takes no channel argument: it pools over channels.
    model_1 = ChannelAttentionModule(in_ch=C, reduction_ratio=8)
    print(model_1(x).shape)  # torch.Size([1, 64, 1, 1])
    model_2 = SpatialAttentionModule()
    print(model_2(x).shape)  # torch.Size([1, 1, 16, 16])

FAQ

如何对特征图的通道方向做平均池化和最大池化

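nn.AdaptiveMaxPool2d(1) 和 nn.AdaptiveAvgPool2d(1) 都是对 H、W 方向进行池化，不能直接用于通道方向；通道方向的池化应改用 torch.mean(x, dim=1, keepdim=True) 和 torch.max(x, dim=1, keepdim=True)（后者只取返回值中的最大值部分）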

比如将[N,C,H,W]→[N,C,1,1]

torch.max 报错

TypeError: expected Tensor as element 1 in argument 0, but got torch.return_types.max

报错信息表明传给 torch.concatenate 的第 2 个元素不是 Tensor，而是 torch.return_types.max，于是检查对应的代码部分

1
2
3
4
5
6
# [N,C,H,W] -> [N,1,H,W]
avg = torch.mean(x,dim=1,keepdim=True)
# [N,C,H,W] -> [N,1,H,W]
# BUG (the error discussed in this section): with `dim` given, torch.max
# returns a (values, indices) named tuple rather than a Tensor, so `max`
# below is a torch.return_types.max and torch.concatenate raises TypeError.
# Fix: unpack the tuple, e.g. `max_val, _ = torch.max(x, dim=1, keepdim=True)`.
max = torch.max(x,dim=1,keepdim=True)
# [N,1,H,W] -> [N,2,H,W]
out = torch.concatenate((avg,max),dim=1)

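打断点查看后发现：指定 dim 时 torch.max 有两个返回值（以命名元组形式返回），第一个返回值是最大值，第二个返回值是最大值对应的下标。直接把这个元组传给 torch.concatenate 就会触发上面的 TypeError，因此需要先解包，只保留最大值：max_val, _ = torch.max(x, dim=1, keepdim=True)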

image.png

cbam模块的sa模块报错

RuntimeError: The size of tensor a (10) must match the size of tensor b (16) at non-singleton dimension 3

查了一下发现是卷积层的 padding 没有设置，默认为 0，这样就不能保证特征图的大小不变：16×16 的特征图经过 7×7 卷积后变为 10×10，与 ca_out 的 16×16 尺寸不一致，因此无法逐元素相乘

1
2
3
4
# Buggy version: padding defaults to 0, so with stride 1 this 7x7 conv shrinks
# H and W by 6 (e.g. 16x16 -> 10x10), causing the size-mismatch RuntimeError.
self.conv = nn.Sequential(
nn.Conv2d(in_channels=2,out_channels=1,kernel_size=7),
nn.ReLU()
)

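卷积输出尺寸公式为 $O = \left\lfloor \frac{I + 2P - K}{S} \right\rfloor + 1$（I 为输入尺寸，K 为卷积核大小，P 为填充，S 为步长）。观察该公式可以知道，当 $S = 1$ 且 $K = 2P+1$ 时 $O = I$，即经过卷积后的特征图大小不变，例如：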

1
2
3
4

K = 7,S = 1,P = 3
K = 3,S = 1,P = 1
K = 1,S = 1,P = 0

因此修改之后成功运行

1
2
3
4
# Fixed: padding=3 (= (7 - 1) / 2) keeps H and W unchanged for stride 1.
# NOTE(review): the ReLU feeding the module's Sigmoid keeps the logits >= 0,
# so the spatial attention weights can only fall in [0.5, 1); CBAM applies
# the sigmoid directly to the conv output.
self.conv = nn.Sequential(
nn.Conv2d(in_channels=2,out_channels=1,kernel_size=7,padding=3),
nn.ReLU()
)