【CVPR-21】Coordinate Attention

架构

image.png

实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torch import nn
import torch.nn.functional as F

class CoordinateAttention(nn.Module):
    """Coordinate Attention block.

    Pools the input separately along the width and the height, mixes the two
    pooled descriptors through a shared channel-reducing 1x1 convolution, and
    produces one sigmoid gate per spatial axis. The input is rescaled by both
    gates, so the output has the same shape as the input.
    """

    def __init__(self, in_ch: int, r: int) -> None:
        """
        :param in_ch: number of channels of the input feature map
        :param r: channel reduction ratio of the bottleneck (``in_ch // r``
            channels are used between the shared conv and the two gate convs)
        :raises ValueError: if ``r`` is not positive or ``in_ch // r`` is < 1
        """
        super().__init__()
        if r <= 0:
            raise ValueError("r must be a positive integer, got {}".format(r))
        mid_ch = in_ch // r
        if mid_ch < 1:
            raise ValueError(
                "in_ch // r must be at least 1, got in_ch={}, r={}".format(in_ch, r)
            )

        # (None, 1): keep H, collapse W -> [N, C, H, 1]
        self.x_avg_pool = nn.AdaptiveAvgPool2d((None, 1))
        # (1, None): collapse H, keep W -> [N, C, 1, W]
        self.y_avg_pool = nn.AdaptiveAvgPool2d((1, None))
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=1, stride=1, padding=0),
        )
        self.act = nn.Sigmoid()
        # BatchNorm alone is affine; without an activation here the squeeze
        # and expand convolutions would compose into a single linear map.
        # Hardswish has no parameters, so state_dict keys are unchanged.
        self.non_linear = nn.Sequential(
            nn.BatchNorm2d(mid_ch),
            nn.Hardswish(),
        )
        self.conv_x = nn.Conv2d(
            in_channels=mid_ch, out_channels=in_ch, kernel_size=1, stride=1, padding=0
        )
        self.conv_y = nn.Conv2d(
            in_channels=mid_ch, out_channels=in_ch, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply coordinate attention to ``x`` of shape [N, C, H, W].

        :param x: input feature map, [N, C, H, W]
        :return: ``x`` scaled by the height gate and the width gate, [N, C, H, W]
        """
        _, _, h, w = x.shape

        # Intermediates are kept as locals (not attributes on self) so that
        # activations and their autograd graph are released after the call.
        # [N,C,H,W] -> [N,C,H,1] -> [N,C,1,H]
        pooled_h = self.x_avg_pool(x).transpose(2, 3)
        # [N,C,H,W] -> [N,C,1,W]
        pooled_w = self.y_avg_pool(x)

        # [N,C,1,H+W] -> [N,C//r,1,H+W]
        mixed = self.conv_1x1(torch.cat([pooled_h, pooled_w], dim=3))
        mixed = self.non_linear(mixed)

        # [N,C//r,1,H+W] -> [N,C//r,1,H] + [N,C//r,1,W]
        feat_h, feat_w = torch.split(mixed, [h, w], dim=3)
        # [N,C//r,1,H] -> [N,C//r,H,1]
        feat_h = feat_h.transpose(2, 3)

        gate_h = self.act(self.conv_x(feat_h))  # [N,C,H,1]
        gate_w = self.act(self.conv_y(feat_w))  # [N,C,1,W]

        # [N,C,H,W] * [N,C,H,1] * [N,C,1,W] (broadcast over W and H)
        return x * gate_h * gate_w

if __name__ == '__main__':
    # Smoke test: the module must return a tensor with the input's shape.
    x = torch.randn(2, 16, 32, 32)
    print("input{}".format(x.shape))
    channels = x.shape[1]
    # The constructor takes only (in_ch, r); H and W are read from the input
    # inside forward, so they are not passed here.
    network = CoordinateAttention(channels, 8)
    print("output{}".format(network(x).shape))

FAQ

对不同空间维度(H 或 W)单独进行池化操作

1
2
3
4
# 只沿 W 维度池化:[N,C,H,W] -> [N,C,H,1]
nn.AdaptiveAvgPool2d((None,1))
# 只沿 H 维度池化:[N,C,H,W] -> [N,C,1,W]
nn.AdaptiveAvgPool2d((1,None))

对一个张量沿着某个维度拆分为两份

沿着dim=3维度拆分张量为两部分,第一个部分大小为h,第二个部分大小为w

1
2
# [N,C,1,H+W] -> [N,C,1,H] + [N,C,1,W]
self.x,self.y = torch.split(out,[h,w],dim=3)