1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| import torch from torch import nn import torch.nn.functional as F
class SimAM(nn.Module):
    """Parameter-free SimAM attention.

    Every activation ``t`` is reweighted by ``sigmoid(1 / e_t)``. Here
    ``1 / e_t = (t - mu)^2 / (4 * (sigma^2 + lambda)) + 0.5``, and ``mu`` /
    ``sigma^2`` are the mean and (N - 1)-normalized variance of the spatial map
    that ``t`` belongs to. Each (sample, channel) map is handled independently.
    """

    def __init__(self, Lambda=0.1):
        """
        :param Lambda: hyperparameter (lambda) of the energy function. It is
            added to the variance in the denominator, which also keeps the
            denominator positive when a feature map has zero variance.
        """
        super().__init__()
        self.Lambda = Lambda

    def forward(self, x):
        """Apply SimAM attention to ``x``.

        :param x: 4-D tensor, unpacked as (batch, channels, height, width).
        :return: tensor with the same shape as ``x``.
        """
        _, _, h, w = x.shape
        # Normalizer N - 1 for the spatial variance. Clamp to 1 so that a 1x1
        # map yields a finite result instead of 0 / 0 = NaN.
        m = max(h * w - 1, 1)
        # Squared deviation from the spatial mean. The parentheses matter:
        # (t - mu)^2, not t - mu^2.
        sq_dev = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        # Variance per sample and per channel: reduce only the spatial dims,
        # so statistics never leak across the batch or across channels.
        var = sq_dev.sum(dim=[2, 3], keepdim=True) / m
        # Inverse of the minimal energy: (t - mu)^2 / (4 (sigma^2 + lambda)) + 1/2.
        inv_energy = sq_dev / (4 * (var + self.Lambda)) + 0.5
        return torch.sigmoid(inv_energy) * x
if __name__ == '__main__':
    # Smoke test: SimAM must preserve the input shape.
    x = torch.randn(2, 8, 32, 32)
    print(f"input shape:{x.shape}")
    model = SimAM()
    # Inference only; no autograd graph is needed for a shape check.
    with torch.no_grad():
        out = model(x)
    print(f"output shape:{out.shape}")
|