背景
参考别人的代码中,它直接将一个标量进行reshape,导致我报错了,但它可以正常运行,目前还不知道是什么原因,后面再看看。
在扩散模型的Forward process中,可以根据以下公式计算得到t时刻的带噪图像xt,其中αt是表示噪声比例的一个系数,是一个标量。在Pytorch中一个标量不能直接和一个不同维度的张量进行广播操作,所以需要将一个标量转换为与目标张量形状相同。
方法
假设x0维度- [N,C,H,W] - [8,3,64,64] , 而αt维度为零维 , 所以想到了以下两种方式
1、先unsqueeze,再repeat,最后循环unsqueeze
1 2 3 4 5
| sqrt_alpha_t_cumprod = self.sqrt_alpha_cumprod[t].unsqueeze(-1).repeat(batch_size) sqrt_one_minus_alpha_t_cumprod = self.sqrt_one_minus_alpha_cumprod[t].unsqueeze(-1).repeat(batch_size) for _ in range(len(original_shape)-1): sqrt_alpha_t_cumprod = sqrt_alpha_t_cumprod.unsqueeze(-1) sqrt_one_minus_alpha_t_cumprod = sqrt_one_minus_alpha_t_cumprod.unsqueeze(-1)
|
αt的维度变化 - [1] - [N] - [N,1] - [N,1,1] - [N,1,1,1]
2、循环unsqueeze,expand_as
1 2 3 4 5 6 7
| sqrt_alpha_t_cumprod = self.sqrt_alpha_cumprod[t] sqrt_one_minus_alpha_t_cumprod = self.sqrt_one_minus_alpha_cumprod[t] for _ in range(len(original_shape)): sqrt_alpha_t_cumprod = sqrt_alpha_t_cumprod.unsqueeze(-1) sqrt_one_minus_alpha_t_cumprod = sqrt_one_minus_alpha_t_cumprod.unsqueeze(-1) sqrt_alpha_t_cumprod = sqrt_alpha_t_cumprod.expand_as(original) sqrt_one_minus_alpha_t_cumprod = sqrt_one_minus_alpha_t_cumprod.expand_as(noise)
|
αt的维度变化 - [1] - [1,1] - [1,1,1] - [1,1,1,1] - [8,3,64,64]
3、作者的方法
1 2 3 4 5
| sqrt_alpha_cumprod = self.sqrt_alpha_cumprod[t].reshape(batch_size) sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod[t].reshape(batch_size) for _ in range(len(original_shape)-1): sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1) sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod.unsqueeze(-1)
|