import torch
from torch import nn

class CoordinateAttention(nn.Module):
    def __init__(self, in_ch, r):
        """
        :param in_ch: number of input channels of the feature map
        :param r: reduction ratio applied to the channel count
        """
        super(CoordinateAttention, self).__init__()
        # Pool along one spatial axis at a time:
        # (n, c, h, w) -> (n, c, h, 1) and (n, c, 1, w)
        self.x_avg_pool = nn.AdaptiveAvgPool2d((None, 1))
        self.y_avg_pool = nn.AdaptiveAvgPool2d((1, None))
        # Shared 1x1 conv that squeezes the channels by the reduction ratio r
        self.conv_1x1 = nn.Conv2d(in_channels=in_ch, out_channels=in_ch // r,
                                  kernel_size=1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(in_ch // r)
        self.act = nn.Sigmoid()
        # Per-direction 1x1 convs that restore the original channel count
        self.conv_x = nn.Conv2d(in_channels=in_ch // r, out_channels=in_ch,
                                kernel_size=1, stride=1, padding=0)
        self.conv_y = nn.Conv2d(in_channels=in_ch // r, out_channels=in_ch,
                                kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        n, c, h, w = x.shape
        # Transpose the height-pooled map to (n, c, 1, h) so both pooled
        # maps can be concatenated along dim 3
        x_direction = self.x_avg_pool(x).transpose(2, 3)
        y_direction = self.y_avg_pool(x)
        out = self.conv_1x1(torch.cat([x_direction, y_direction], dim=3))
        out = self.bn(out)
        # Split back into the two directions and undo the transpose
        x_attn, y_attn = torch.split(out, [h, w], dim=3)
        x_attn = x_attn.transpose(2, 3)
        out_x = self.act(self.conv_x(x_attn))  # (n, c, h, 1) attention along height
        out_y = self.act(self.conv_y(y_attn))  # (n, c, 1, w) attention along width
        return x * out_x * out_y

if __name__ == '__main__':
    x = torch.randn(2, 16, 32, 32)
    print("input: {}".format(x.shape))
    N, C, H, W = x.shape
    # The constructor takes (in_ch, r), not the spatial dimensions
    network = CoordinateAttention(C, 8)
    print("output: {}".format(network(x).shape))
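
Since the block returns a tensor of the same shape as its input, it can be dropped in after any convolutional stage. Below is a minimal sketch of such usage; the SmallNet model and its layer sizes are illustrative assumptions, not part of the original code.

class SmallNet(nn.Module):
    # Hypothetical toy CNN showing where the attention block slots in
    def __init__(self, num_classes=10):
        super(SmallNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            CoordinateAttention(16, 8),  # reweights the 16-channel map, shape unchanged
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x).flatten(1))

# e.g. SmallNet()(torch.randn(2, 3, 32, 32)).shape -> torch.Size([2, 10])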