def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Rescale each channel of ``x`` by a weight derived from its global context.

    1. Squeeze: ``self.Global_pooling`` reduces ``x`` to one value per channel.
    2. Excite: ``self.fc`` maps those per-channel values to per-channel weights.
    3. Scale: the weights are broadcast over the spatial dims and multiplied
       into ``x``.

    :param x: input feature map of shape [N, C, H, W]
    :return: tensor of shape [N, C, H, W]; ``x`` with each channel rescaled
    """
    # Unpacking all four dims keeps the original's implicit "x must be 4-D"
    # check; the spatial sizes themselves are not needed.
    n, c, _, _ = x.shape
    # [N, C, H, W] -> [N, C, 1, 1] -> [N, C]. view() requires the pooling layer
    # to yield exactly N*C elements, i.e. to collapse H and W to 1x1
    # (presumably an adaptive average pool -- confirm in __init__).
    weights = self.Global_pooling(x).view(n, c)
    # [N, C] -> [N, C]. NOTE(review): presumably a bottleneck MLP ending in a
    # gate such as sigmoid -- confirm in __init__.
    weights = self.fc(weights)
    # [N, C] -> [N, C, 1, 1] so the weights broadcast over H and W.
    weights = weights[:, :, None, None]
    # [N, C, H, W] * [N, C, 1, 1] -> [N, C, H, W]
    return x * weights
if __name__ == '__main__':
    # Smoke test: push a random batch through an SE block and report shapes.
    x = torch.randn(2, 64, 32, 32)
    print("input_shape{}".format(x.shape))
    model = SENets(in_ch=64, reduction_ratio=8)
    out = model(x)
    print("output_shape{}".format(out.shape))
    # forward() multiplies x by a [N, C, 1, 1] tensor, so the block must
    # leave the shape unchanged; fail loudly if it does not.
    assert out.shape == x.shape, "SENets must preserve the input shape"