Dataset-CelebA

This part covers two things: data augmentation with transforms and loading unlabeled data with DatasetFolder.
Data augmentation with transforms

```python
transforms = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
```
Loading unlabeled data with DatasetFolder

```python
celeba_set = DatasetFolder(
    root=win_root,
    loader=loader,
    extensions="jpg",
    transform=transforms
)
```
Complete code:

```python
from torchvision.datasets import DatasetFolder
from torchvision.transforms import transforms
from PIL import Image

# Use a raw string so the backslashes are not treated as escape sequences.
win_root = r"F:\dataset\CelebA"
loader = lambda x: Image.open(x)
extensions = "jpg"

transforms = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

celeba_set = DatasetFolder(
    root=win_root,
    loader=loader,
    extensions="jpg",
    transform=transforms
)
```
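A minimal usage sketch, not from the original post: wrapping the dataset in a DataLoader. Note that DatasetFolder expects the images to sit inside at least one sub-directory of root, which here acts as a dummy class whose index we discard; the batch size below is an arbitrary choice.

```python
from torch.utils.data import DataLoader

celeba_loader = DataLoader(celeba_set, batch_size=64, shuffle=True, num_workers=0)

for images, _ in celeba_loader:   # the class index from the dummy folder is ignored
    print(images.shape)           # torch.Size([64, 3, 64, 64])
    break
```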
Scheduler-LinearScheduler

The linear scheduler consists of two parts: the diffusion (noising) process and the denoise process.
Initialization: the goal of __init__ is to precompute all the schedule quantities up front.
```python
def __init__(self, time_step, beta_start, beta_end):
    """
    :param time_step: number of timesteps T
    :param beta_start: starting value of the beta schedule
    :param beta_end: final value of the beta schedule
    """
    self.time_step = time_step
    self.beta_start = beta_start
    self.beta_end = beta_end
    self.beta_t = torch.linspace(beta_start, beta_end, time_step)
    self.alpha_t = 1. - self.beta_t
    self.alpha_t_bar = torch.cumprod(self.alpha_t, dim=0)
    self.one_minus_alpha_t_bar = 1. - self.alpha_t_bar
    self.one_minus_alpha_t = 1. - self.alpha_t
    self.sqrt_alpha_t_bar = torch.sqrt(self.alpha_t_bar)
    self.sqrt_one_minus_alpha_t_bar = torch.sqrt(self.one_minus_alpha_t_bar)
```
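To build intuition for what gets precomputed, here is an illustrative check with a deliberately tiny T = 5; the 0.0001..0.02 beta range is a common DDPM choice assumed for the demo, not something stated in this post:

```python
import torch

beta_t = torch.linspace(0.0001, 0.02, 5)
alpha_t = 1. - beta_t
alpha_t_bar = torch.cumprod(alpha_t, dim=0)
print(beta_t)                         # 5 evenly spaced values from 0.0001 to 0.02
print(alpha_t_bar)                    # strictly decreasing: later steps keep less signal
print(torch.sqrt(1. - alpha_t_bar))   # the noise coefficient grows with t
```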
The diffusion process

The forward process computes the following formula, which gives the noisy image at an arbitrary timestep in closed form:

$$x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$
```python
def diffusion(self, x_0, t, noise):
    """
    Forward (noising) process of the diffusion model.
    :param x_0: source image x_0
    :param t: timestep t
    :param noise: Gaussian noise added to x_0
    :return: noisy image x_t at timestep t
    """
    sqrt_alpha_t_bar = self.sqrt_alpha_t_bar[t]
    sqrt_one_minus_alpha_t_bar = self.sqrt_one_minus_alpha_t_bar[t]
    # Expand (N,) -> (N, 1, 1, 1) so the coefficients broadcast against x_0.
    for _ in range(len(x_0.shape) - 1):
        sqrt_alpha_t_bar = sqrt_alpha_t_bar.unsqueeze(-1)
    for _ in range(len(x_0.shape) - 1):
        sqrt_one_minus_alpha_t_bar = sqrt_one_minus_alpha_t_bar.unsqueeze(-1)
    x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
    return x_t
```
The denoise process

One reverse step computes the posterior mean and variance:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad \sigma_t^2 = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t, \quad z \sim \mathcal{N}(0, I)$$
```python
def denoise(self, x_t, t, noise):
    """
    Reverse (denoising) process of the diffusion model.
    :param x_t: noisy image at timestep t
    :param t: timestep t
    :param noise: predicted noise to remove
    :return: denoised image x_{t-1}; at t = 0 return the mean directly,
             otherwise mean plus the sampled variance term
    """
    # DDPM posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
    mean = (x_t - self.beta_t[t] * noise / self.sqrt_one_minus_alpha_t_bar[t]) / torch.sqrt(self.alpha_t[t])
    if t == 0:
        return mean
    # sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
    variance = (1. - self.alpha_t_bar[t - 1]) * (1. - self.alpha_t[t]) / self.one_minus_alpha_t_bar[t]
    z = torch.randn_like(x_t)
    return mean + variance ** 0.5 * z
```
Complete code:

```python
import torch


class LinearScheduler:
    """
    Linear scheduler for the diffusion model, in two parts:
    1. the diffusion process, which produces the noisy image at any timestep;
    2. the denoise process, which denoises the previous-step image.
    """

    def __init__(self, time_step, beta_start, beta_end):
        """
        :param time_step: number of timesteps T
        :param beta_start: starting value of the beta schedule
        :param beta_end: final value of the beta schedule
        """
        self.time_step = time_step
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_t = torch.linspace(beta_start, beta_end, time_step)
        self.alpha_t = 1. - self.beta_t
        self.alpha_t_bar = torch.cumprod(self.alpha_t, dim=0)
        self.one_minus_alpha_t_bar = 1. - self.alpha_t_bar
        self.one_minus_alpha_t = 1. - self.alpha_t
        self.sqrt_alpha_t_bar = torch.sqrt(self.alpha_t_bar)
        self.sqrt_one_minus_alpha_t_bar = torch.sqrt(self.one_minus_alpha_t_bar)

    def diffusion(self, x_0, t, noise):
        """
        Forward (noising) process of the diffusion model.
        :param x_0: source image x_0
        :param t: timestep t
        :param noise: Gaussian noise added to x_0
        :return: noisy image x_t at timestep t
        """
        sqrt_alpha_t_bar = self.sqrt_alpha_t_bar[t]
        sqrt_one_minus_alpha_t_bar = self.sqrt_one_minus_alpha_t_bar[t]
        # Expand (N,) -> (N, 1, 1, 1) so the coefficients broadcast against x_0.
        for _ in range(len(x_0.shape) - 1):
            sqrt_alpha_t_bar = sqrt_alpha_t_bar.unsqueeze(-1)
        for _ in range(len(x_0.shape) - 1):
            sqrt_one_minus_alpha_t_bar = sqrt_one_minus_alpha_t_bar.unsqueeze(-1)
        x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
        return x_t

    def denoise(self, x_t, t, noise):
        """
        Reverse (denoising) process of the diffusion model.
        :param x_t: noisy image at timestep t
        :param t: timestep t
        :param noise: predicted noise to remove
        :return: denoised image x_{t-1}; at t = 0 return the mean directly,
                 otherwise mean plus the sampled variance term
        """
        # DDPM posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (x_t - self.beta_t[t] * noise / self.sqrt_one_minus_alpha_t_bar[t]) / torch.sqrt(self.alpha_t[t])
        if t == 0:
            return mean
        # sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        variance = (1. - self.alpha_t_bar[t - 1]) * (1. - self.alpha_t[t]) / self.one_minus_alpha_t_bar[t]
        z = torch.randn_like(x_t)
        return mean + variance ** 0.5 * z
```
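An end-to-end sketch of how the scheduler is driven. The random noise predictor below is only a stand-in for the trained UNet; it demonstrates the tensor plumbing, not real sampling:

```python
scheduler = LinearScheduler(time_step=1000, beta_start=0.0001, beta_end=0.02)

# Forward: noise a fake batch at random timesteps.
x_0 = torch.randn(4, 3, 64, 64)
t = torch.randint(0, 1000, (4,))           # one timestep per sample
noise = torch.randn_like(x_0)
x_t = scheduler.diffusion(x_0, t, noise)
print(x_t.shape)                           # torch.Size([4, 3, 64, 64])

# Reverse: walk a single sample from t = T-1 down to 0.
x = torch.randn(1, 3, 64, 64)
for step in reversed(range(1000)):
    fake_noise_pred = torch.randn_like(x)  # replace with model(x, step) in practice
    x = scheduler.denoise(x, step, fake_noise_pred)
```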
Problems encountered

When I first wrote the diffusion step, the code raised a dimension-mismatch error. The original version is shown below:

```python
def diffusion(self, x_0, t, noise):
    r"""
    Forward (noising) process of the diffusion model.
    :param x_0: source image x_0
    :param t: timestep t
    :param noise: Gaussian noise added to x_0
    :return: noisy image x_t at timestep t
    """
    x_t = self.sqrt_alpha_t_bar[t] * x_0 + self.sqrt_one_minus_alpha_t_bar[t] * noise
    return x_t
```
Here the coefficient has shape (N,) while x_0 has shape (N, 3, 64, 64). PyTorch broadcasting aligns shapes from the trailing dimension, so the (N,) tensor lines up with the width dimension rather than the batch dimension, and the shapes cannot broadcast. The coefficient has to be reshaped to (N, 1, 1, 1) to match x_0 before the multiply.
The fixed version:

```python
def diffusion(self, x_0, t, noise):
    """
    Forward (noising) process of the diffusion model.
    :param x_0: source image x_0
    :param t: timestep t
    :param noise: Gaussian noise added to x_0
    :return: noisy image x_t at timestep t
    """
    sqrt_alpha_t_bar = self.sqrt_alpha_t_bar[t]
    sqrt_one_minus_alpha_t_bar = self.sqrt_one_minus_alpha_t_bar[t]
    # Expand (N,) -> (N, 1, 1, 1) so the coefficients broadcast against x_0.
    for _ in range(len(x_0.shape) - 1):
        sqrt_alpha_t_bar = sqrt_alpha_t_bar.unsqueeze(-1)
    for _ in range(len(x_0.shape) - 1):
        sqrt_one_minus_alpha_t_bar = sqrt_one_minus_alpha_t_bar.unsqueeze(-1)
    x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
    return x_t
```
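As a side note, the two unsqueeze loops can also be collapsed into a single reshape; a self-contained shape demo, where coeff stands in for self.sqrt_alpha_t_bar[t]:

```python
import torch

coeff = torch.rand(4)                               # shape (N,)
x_0 = torch.randn(4, 3, 64, 64)                     # shape (N, 3, 64, 64)
coeff = coeff.reshape(-1, *([1] * (x_0.dim() - 1)))  # (N,) -> (N, 1, 1, 1) in one step
print(coeff.shape)                                  # torch.Size([4, 1, 1, 1])
print((coeff * x_0).shape)                          # broadcasts to torch.Size([4, 3, 64, 64])
```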
Model-UNet

The model consists of time_embedding, downblock, midblock, and upblock.
time_embedding

This module turns the timestep into a positional signal, so that the diffusion model can learn the relationship between t and the noise. The encoding is the same sinusoidal scheme as in the transformer, following this formula:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

(The implementation below concatenates the sine half and the cosine half instead of interleaving them, which works just as well as a positional code.)
```python
import torch


def get_time_embedding(pos, d_model):
    """
    :param pos: position(s) to encode, a scalar or 1-D tensor of timesteps
    :param d_model: embedding dimension of the positional code
    :return: positional encoding for pos, shape (N, d_model)
    """
    # 10000^(2i / d_model) for i = 0 .. d_model/2 - 1
    factor = 10000 ** (torch.arange(0, d_model // 2) / (d_model // 2))
    pos = pos.unsqueeze(-1).repeat(1, d_model // 2)
    angle = pos / factor
    # First half sine, second half cosine, concatenated along the feature dim.
    pos_code = torch.cat([torch.sin(angle), torch.cos(angle)], dim=1)
    return pos_code
```
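A quick usage check (d_model = 128 matches the test at the end of this post):

```python
pos = torch.tensor([0, 10, 100, 999])   # a batch of 4 timesteps
emb = get_time_embedding(pos, 128)
print(emb.shape)                        # torch.Size([4, 128])
```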
downblock

The down path consists of a Resnet module and a self-attention (SA) module, taking the input through (N, C1, H, W) → (N, C2, H/2, W/2) → (N, C3, H/4, W/4) → (N, C4, H/8, W/8).
Resnet module

The Resnet module consists of two convolution layers, first_conv and second_conv, both with k=3, s=1, p=1, so each convolution changes only the channel count and never the spatial size (output size = (H + 2·1 − 3)/1 + 1 = H; see the check after the code below). The time embedding time_emb is projected and added to the output of first_conv. Overall, a tensor of shape (N, C1, H, W) passes through the Resnet module and comes out as (N, C2, H, W), where C2 = 2·C1 (for example C1 = 32, C2 = 64).
```python
# Excerpt from DownBlock (the signatures are filled in to match the test
# at the end of this post; the original listing omitted them).
def __init__(self, in_ch, out_ch, t_emb, down_sample=True):
    super(DownBlock, self).__init__()
    self.first_conv = nn.Sequential(
        nn.GroupNorm(num_groups=8, num_channels=in_ch),
        nn.SiLU(),
        nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    )
    self.second_conv = nn.Sequential(
        nn.GroupNorm(num_groups=8, num_channels=out_ch),
        nn.SiLU(),
        nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    )
    self.time_proj = nn.Sequential(
        nn.SiLU(),
        nn.Linear(in_features=t_emb, out_features=out_ch)
    )
    self.res_conv = nn.Sequential(
        nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, stride=1, padding=0)
    )

def forward(self, input, t_emb):
    out = self.first_conv(input)
    # Project the time embedding to (N, out_ch, 1, 1) and add it per channel.
    out = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1) + out
    out = self.second_conv(out)
    # 1x1 residual shortcut to match the channel count of the main branch.
    out = out + self.res_conv(input)
    # ... continued in the SA part below ...
```
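The claim that k=3, s=1, p=1 preserves spatial size while changing channels is easy to verify in isolation (an illustrative check, not from the original):

```python
import torch
from torch import nn

# (64 + 2*1 - 3)/1 + 1 = 64: spatial size unchanged, channels 32 -> 64.
conv = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
x = torch.randn(1, 32, 64, 64)
print(conv(x).shape)   # torch.Size([1, 64, 64, 64])
```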
SA module

A tensor of shape (N, C2, H, W) passes through the SA module and keeps its shape (N, C2, H, W); the downsampling conv at the end of the block then halves H and W.
```python
# Excerpt from DownBlock (self-attention part).
def __init__(self, in_ch, out_ch, t_emb, down_sample=True):
    # ...
    self.self_attn_norm = nn.GroupNorm(num_groups=8, num_channels=out_ch)
    self.self_attn = nn.MultiheadAttention(embed_dim=out_ch, num_heads=8, batch_first=True)
    # Assumed definition (not shown in the original excerpt, but called below):
    # the downsampling conv that halves H and W at the end of the block.
    self.down_conv = nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1) if down_sample else nn.Identity()

def forward(self, input, t_emb):
    # ... Resnet part as above, producing `out` ...
    batch, ch, h, w = out.shape
    attn_in = out
    attn_in = attn_in.reshape(batch, ch, h * w)   # (N, C, HW)
    attn_in = self.self_attn_norm(attn_in)
    attn_in = attn_in.transpose(1, 2)             # (N, HW, C) for batch_first attention
    attn_out, _ = self.self_attn(attn_in, attn_in, attn_in)
    attn_out = attn_out.transpose(1, 2)
    attn_out = attn_out.reshape(batch, ch, h, w)
    out = out + attn_out                          # residual connection
    out = self.down_conv(out)                     # halve H and W
```
How to understand the use of nn.MultiheadAttention: with batch_first=True, Q, K, and V must each have shape (N, L, E). Our input from the previous step is (N, C2, H, W), so it has to be reshaped to (N, C2, HW) and then transposed to (N, HW, C2) before it can be fed in; each of the HW spatial positions becomes a token with embedding dimension C2.
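A standalone demo of that round trip (the shapes are chosen to match the mid-level feature map; nothing here comes from the original code):

```python
import torch
from torch import nn

N, C, H, W = 1, 64, 16, 16
feat = torch.randn(N, C, H, W)
attn = nn.MultiheadAttention(embed_dim=C, num_heads=8, batch_first=True)

seq = feat.reshape(N, C, H * W).transpose(1, 2)   # (N, HW, C): HW tokens of dim C
out, _ = attn(seq, seq, seq)                      # self-attention over spatial positions
out = out.transpose(1, 2).reshape(N, C, H, W)     # back to feature-map layout
print(out.shape)                                  # torch.Size([1, 64, 16, 16])
```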
midblock

Likewise contains a Resnet module and an SA module, mapping the input (N, C4, H/8, W/8) → (N, C3, H/8, W/8).
```python
class MidBlock(nn.Module):
    def __init__(self, in_ch, out_ch, t_emb, num_head=8):
        """
        :param in_ch: number of input channels
        :param out_ch: number of output channels
        :param t_emb: dimension of the time positional embedding
        :param num_head: number of attention heads
        """
        super(MidBlock, self).__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.t_emb = t_emb
        self.first_conv = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=in_ch),
            nn.SiLU(),
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        )
        self.second_conv = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=out_ch),
            nn.SiLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        )
        self.t_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=t_emb, out_features=out_ch)
        )
        self.resnet_conv = nn.Sequential(
            nn.SiLU(),
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, stride=1, padding=0)
        )
        self.third_conv = nn.Sequential(
            nn.SiLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        )
        self.self_attn_norm = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.self_attn = nn.MultiheadAttention(embed_dim=out_ch, num_heads=num_head, batch_first=True)

    def forward(self, input, t):
        # Resnet part: two convs with the time embedding added in between.
        out = self.first_conv(input)
        out = out + self.t_proj(t).unsqueeze(-1).unsqueeze(-1)
        out = self.second_conv(out)
        out = out + self.resnet_conv(input)
        out_copy_first = out
        # Self-attention part over the HW spatial positions.
        batch, ch, h, w = out.shape
        out = out.reshape(batch, ch, h * w)
        in_attn = self.self_attn_norm(out)
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.self_attn(in_attn, in_attn, in_attn)
        out = out_attn.transpose(1, 2).reshape(batch, ch, h, w)
        out = out + out_copy_first
        # Final convs (note: third_conv is applied twice, sharing weights) plus a residual.
        out_copy_second = out
        out = self.third_conv(out)
        out = self.third_conv(out)
        out = out_copy_second + out
        return out
```
upblock

Mainly contains a Resnet module, taking the input through (N, C3, H/8, W/8) → (N, C2, H/4, W/4) → (N, C1, H/2, W/2) → (N, 16, H, W) → (N, 3, H, W).
```python
class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        """
        :param in_ch: number of input channels of the feature map
        :param out_ch: number of output channels of the feature map
        """
        super(UpBlock, self).__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        # Transposed conv with k=4, s=2, p=1 doubles H and W.
        self.transpose_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_ch, out_channels=in_ch, kernel_size=4, stride=2, padding=1)
        )
        # After concatenating the skip connection the channel count is in_ch * 2.
        self.upsample_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_ch * 2, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, input, down_input):
        out = self.transpose_conv(input)
        # Concatenate the matching feature map from the down path (skip connection).
        out = torch.cat([out, down_input], dim=1)
        out = self.upsample_conv(out)
        return out
```
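The transposed convolution's doubling behaviour can be checked the same way as the k=3 conv earlier; for ConvTranspose2d, H_out = (H − 1)·s − 2p + k = (16 − 1)·2 − 2 + 4 = 32 (an illustrative check, not from the original):

```python
up_conv = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
x = torch.randn(1, 64, 16, 16)
print(up_conv(x).shape)   # torch.Size([1, 64, 32, 32])
```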
Test:

```python
if __name__ == '__main__':
    x = torch.randn(1, 32, 64, 64)
    t = get_time_embedding(torch.as_tensor(10), 128)
    down_1 = DownBlock(in_ch=32, out_ch=64, down_sample=True, t_emb=128)
    down_2 = DownBlock(in_ch=64, out_ch=128, down_sample=True, t_emb=128)
    mid = MidBlock(in_ch=128, out_ch=64, t_emb=128)
    up = UpBlock(in_ch=64, out_ch=32)
    down_1_ans = down_1(x, t)
    down_2_ans = down_2(down_1_ans, t)
    print("down 1: {}".format(down_1_ans.shape))   # (1, 64, 32, 32)
    print("down 2: {}".format(down_2_ans.shape))   # (1, 128, 16, 16)
    mid_ans = mid(down_2_ans, t)
    print("mid: {}".format(mid_ans.shape))         # (1, 64, 16, 16)
    out_ans = up(mid_ans, down_1_ans)
    print("up: {}".format(out_ans.shape))          # (1, 32, 32, 32)
```
Complete code