背景


参考别人的代码中,它直接将一个标量进行reshape,导致我报错了,但它可以正常运行,目前还不知道是什么原因,后面再看看。
在扩散模型的Forward process中,可以根据以下公式计算得到t时刻的带噪图像xt,其中αt是表示噪声比例的一个系数,是一个标量。在Pytorch中一个标量不能直接和一个不同维度的张量进行广播操作,所以需要将一个标量转换为与目标张量形状相同。

方法


假设x0维度- [N,C,H,W] - [8,3,64,64] , 而αt维度为零维 , 所以想到了以下两种方式

1、先unsqueeze,再repeat,最后循环unsqueeze

1
2
3
4
5
sqrt_alpha_t_cumprod = self.sqrt_alpha_cumprod[t].unsqueeze(-1).repeat(batch_size)
sqrt_one_minus_alpha_t_cumprod = self.sqrt_one_minus_alpha_cumprod[t].unsqueeze(-1).repeat(batch_size)
for _ in range(len(original_shape)-1):
sqrt_alpha_t_cumprod = sqrt_alpha_t_cumprod.unsqueeze(-1)
sqrt_one_minus_alpha_t_cumprod = sqrt_one_minus_alpha_t_cumprod.unsqueeze(-1)

αt的维度变化 - [1] - [N] - [N,1] - [N,1,1] - [N,1,1,1]

2、循环unsqueeze,expand_as

1
2
3
4
5
6
7
sqrt_alpha_t_cumprod = self.sqrt_alpha_cumprod[t]
sqrt_one_minus_alpha_t_cumprod = self.sqrt_one_minus_alpha_cumprod[t]
for _ in range(len(original_shape)):
sqrt_alpha_t_cumprod = sqrt_alpha_t_cumprod.unsqueeze(-1)
sqrt_one_minus_alpha_t_cumprod = sqrt_one_minus_alpha_t_cumprod.unsqueeze(-1)
sqrt_alpha_t_cumprod = sqrt_alpha_t_cumprod.expand_as(original)
sqrt_one_minus_alpha_t_cumprod = sqrt_one_minus_alpha_t_cumprod.expand_as(noise)

αt的维度变化 - [1] - [1,1] - [1,1,1] - [1,1,1,1] - [8,3,64,64]

3、作者的方法

1
2
3
4
5
sqrt_alpha_cumprod = self.sqrt_alpha_cumprod[t].reshape(batch_size)
sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod[t].reshape(batch_size)
for _ in range(len(original_shape)-1):
sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1)
sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod.unsqueeze(-1)