【ICML-2021】SimAM

架构

能量函数

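对通道内的目标神经元 $t$ 与其余神经元 $x_i$（$M = H \times W$），采用二值标签并加入正则项后的能量函数为：

$$e_t(w_t, b_t, \mathbf{y}, x_i) = \frac{1}{M-1}\sum_{i=1}^{M-1}\bigl(-1-(w_t x_i + b_t)\bigr)^2 + \bigl(1-(w_t t + b_t)\bigr)^2 + \lambda w_t^2$$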

能量函数的最小值（闭式解）

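$$e_t^{*} = \frac{4(\hat{\sigma}^2+\lambda)}{(t-\hat{\mu})^2 + 2\hat{\sigma}^2 + 2\lambda},\qquad \hat{\mu} = \frac{1}{M}\sum_{i=1}^{M} x_i,\qquad \hat{\sigma}^2 = \frac{1}{M}\sum_{i=1}^{M}(x_i-\hat{\mu})^2$$

因此 $\dfrac{1}{e_t^{*}} = \dfrac{(t-\hat{\mu})^2}{4(\hat{\sigma}^2+\lambda)} + 0.5$，能量越低的神经元越重要（下方实现中方差的分母取 $M-1$）。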

最终的输出结果

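$$\widetilde{\mathbf{X}} = \operatorname{sigmoid}\!\left(\frac{1}{\mathbf{E}}\right) \odot \mathbf{X}$$

其中 $\mathbf{E}$ 汇集了所有通道、所有空间位置上的 $e_t^{*}$。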

实现

import torch
from torch import nn
import torch.nn.functional as F

class SimAM(nn.Module):
    """SimAM: a simple, parameter-free attention module (ICML 2021).

    Every activation ``t`` of the input is re-weighted by ``sigmoid(1 / e_t*)``.
    ``e_t*`` is the closed-form minimal energy of that neuron relative to
    the other neurons in the same channel of the same sample::

        1 / e_t* = (t - mu)^2 / (4 * (sigma^2 + Lambda)) + 0.5

    ``mu`` is the per-sample, per-channel spatial mean. ``sigma^2`` is the
    matching spatial variance, normalised by ``h * w - 1``.

    The module has no learnable parameters.
    """

    def __init__(self, Lambda=0.1):
        """
        :param Lambda: regularisation hyper-parameter (lambda) of the energy
            function. It keeps the denominator positive when a channel has
            zero spatial variance.
            NOTE(review): the official SimAM code defaults to 1e-4 -- confirm
            that 0.1 is intended here.
        """
        super(SimAM, self).__init__()
        self.Lambda = Lambda

    def forward(self, x):
        """Apply SimAM attention.

        :param x: 4-D tensor unpacked as (n, c, h, w).
        :return: tensor with the same shape as ``x``.
        """
        n, c, h, w = x.shape
        # Number of "other" neurons in a channel, used as the variance divisor.
        # Clamped to 1 so that a 1x1 feature map does not compute 0 / 0 (NaN).
        M = max(h * w - 1, 1)
        # Squared deviation from the per-channel spatial mean. The parentheses
        # matter: subtract first, then square.
        t_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        # Variance per sample and per channel, kept broadcastable against x.
        # It must not be a single scalar pooled over the whole batch.
        sigma_square = t_minus_mu_square.sum(dim=[2, 3], keepdim=True) / M
        # 1 / e_t*: lambda is added to the variance *before* scaling by 4.
        one_divide_e = t_minus_mu_square / (4 * (sigma_square + self.Lambda)) + 0.5
        return torch.sigmoid(one_divide_e) * x

if __name__ == '__main__':
    # Smoke test: SimAM is parameter-free and should return a tensor with the
    # same shape as its input.
    x = torch.randn(2, 8, 32, 32)
    print("input shape:{}".format(x.shape))
    model = SimAM()
    print("output shape:{}".format(model(x).shape))

FAQ