Dataset-CelebA

Mainly covers data augmentation with transforms and loading unlabeled data with DatasetFolder.

Data augmentation with transforms

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
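Normalize with mean 0.5 and std 0.5 per channel maps the tensor values from [0, 1] to [-1, 1], which is the value range commonly used when training diffusion models.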

Loading unlabeled data with DatasetFolder

# Load the data (unlabeled; DatasetFolder still expects at least one
# subfolder under root, and treats each subfolder as one class)
celeba_set = DatasetFolder(
    root=win_root,
    loader=loader,
    extensions="jpg",
    transform=transform
)

Complete code

from torchvision.datasets import DatasetFolder
from torchvision.transforms import transforms
from PIL import Image

win_root = "F:\\dataset\\CelebA\\"    # escape the backslashes in the Windows path
loader = lambda x: Image.open(x)      # each sample is opened as a PIL image
extensions = "jpg"
# Data augmentation
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# Load the data (unlabeled; each subfolder under root is treated as one class)
celeba_set = DatasetFolder(
    root=win_root,
    loader=loader,
    extensions="jpg",
    transform=transform
)

# Test
# if __name__ == '__main__':
#     print(len(celeba_set.samples))
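As a usage sketch (the batch size and worker count are arbitrary values here, not from the original setup), the dataset can be wrapped in a DataLoader for training:

from torch.utils.data import DataLoader

celeba_loader = DataLoader(celeba_set, batch_size=64, shuffle=True, num_workers=2)
images, _ = next(iter(celeba_loader))   # the label is just the dummy folder index
print(images.shape)                     # torch.Size([64, 3, 64, 64])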

Scheduler-LinearScheduler

It consists of two parts: the diffusion (forward, noising) process and the denoise (reverse) process.

Initialization

The purpose of initialization is to precompute all the quantities that will be needed later.

def __init__(self, time_step, beta_start, beta_end):
    """
    :param time_step: number of time steps T
    :param beta_start: starting value of the beta schedule
    :param beta_end: ending value of the beta schedule
    """
    self.time_step = time_step
    self.beta_start = beta_start
    self.beta_end = beta_end
    self.beta_t = torch.linspace(beta_start, beta_end, time_step)   # linear beta schedule
    self.alpha_t = 1. - self.beta_t
    self.alpha_t_bar = torch.cumprod(self.alpha_t, dim=0)           # cumulative product of alphas
    self.one_minus_alpha_t_bar = 1. - self.alpha_t_bar
    self.one_minus_alpha_t = 1. - self.alpha_t
    self.sqrt_alpha_t_bar = torch.sqrt(self.alpha_t_bar)
    self.sqrt_one_minus_alpha_t_bar = torch.sqrt(self.one_minus_alpha_t_bar)
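In DDPM notation, these precomputed buffers correspond to:

$$\beta_t = \mathrm{linspace}(\beta_{start}, \beta_{end}, T), \qquad \alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$$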

Diffusion process

The diffusion process mainly computes the following formula:
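$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

That is, x_t is obtained from x_0 in a single step by mixing it with Gaussian noise ε, which is exactly what the code below computes.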

def diffusion(self, x_0, t, noise):
    """
    Noising (forward) process of the diffusion model.
    :param x_0: source image x_0
    :param t: the t-th time step
    :param noise: the noise added to x_0
    :return: the noisy image x_t at time step t
    """
    sqrt_alpha_t_bar = self.sqrt_alpha_t_bar[t]
    sqrt_one_minus_alpha_t_bar = self.sqrt_one_minus_alpha_t_bar[t]
    # expand (N,) -> (N, 1, 1, 1) so the coefficients broadcast over x_0
    for _ in range(len(x_0.shape) - 1):
        sqrt_alpha_t_bar = sqrt_alpha_t_bar.unsqueeze(-1)
    for _ in range(len(x_0.shape) - 1):
        sqrt_one_minus_alpha_t_bar = sqrt_one_minus_alpha_t_bar.unsqueeze(-1)
    x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
    return x_t

Denoise process

The denoise process mainly computes the following formula:
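$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad \sigma_t^2 = \frac{(1-\bar{\alpha}_{t-1})\,\beta_t}{1-\bar{\alpha}_t}, \quad z \sim \mathcal{N}(0, I)$$

Here ε_θ(x_t, t) is the noise predicted by the network and σ_t² is the posterior variance; at t = 0 only the mean is returned.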

def denoise(self, x_t, t, noise):
    """
    Denoising (reverse) process of the diffusion model.
    :param x_t: noisy image at time step t
    :param t: time step t
    :param noise: the noise to remove
    :return: the denoised image x_{t-1}; at t=0 the mean is returned directly,
             otherwise noise scaled by the posterior std is added to the mean
    """
    # posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * noise) / sqrt(alpha_t)
    mean = (x_t - self.beta_t[t] * noise / self.sqrt_one_minus_alpha_t_bar[t]) / torch.sqrt(self.alpha_t[t])
    if t == 0:
        return mean
    # posterior variance: (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
    variance = (1. - self.alpha_t_bar[t - 1]) * (1. - self.alpha_t[t]) / self.one_minus_alpha_t_bar[t]
    z = torch.randn_like(x_t)
    return mean + variance ** 0.5 * z

Complete code

import torch

class LinearScheduler:
    """
    Linear scheduler for the diffusion model, consisting of two parts:
    1. the diffusion process, which produces the noisy image at any time step
    2. the denoise process, which denoises the image at the current step to obtain the previous one
    """

    def __init__(self, time_step, beta_start, beta_end):
        """
        :param time_step: number of time steps T
        :param beta_start: starting value of the beta schedule
        :param beta_end: ending value of the beta schedule
        """
        self.time_step = time_step
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_t = torch.linspace(beta_start, beta_end, time_step)
        self.alpha_t = 1. - self.beta_t
        self.alpha_t_bar = torch.cumprod(self.alpha_t, dim=0)
        self.one_minus_alpha_t_bar = 1. - self.alpha_t_bar
        self.one_minus_alpha_t = 1. - self.alpha_t
        self.sqrt_alpha_t_bar = torch.sqrt(self.alpha_t_bar)
        self.sqrt_one_minus_alpha_t_bar = torch.sqrt(self.one_minus_alpha_t_bar)

    def diffusion(self, x_0, t, noise):
        """
        Noising (forward) process of the diffusion model.
        :param x_0: source image x_0
        :param t: the t-th time step
        :param noise: the noise added to x_0
        :return: the noisy image x_t at time step t
        """
        sqrt_alpha_t_bar = self.sqrt_alpha_t_bar[t]
        sqrt_one_minus_alpha_t_bar = self.sqrt_one_minus_alpha_t_bar[t]
        # expand (N,) -> (N, 1, 1, 1) so the coefficients broadcast over x_0
        for _ in range(len(x_0.shape) - 1):
            sqrt_alpha_t_bar = sqrt_alpha_t_bar.unsqueeze(-1)
        for _ in range(len(x_0.shape) - 1):
            sqrt_one_minus_alpha_t_bar = sqrt_one_minus_alpha_t_bar.unsqueeze(-1)
        x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
        return x_t

    def denoise(self, x_t, t, noise):
        """
        Denoising (reverse) process of the diffusion model.
        :param x_t: noisy image at time step t
        :param t: time step t
        :param noise: the noise to remove
        :return: the denoised image x_{t-1}; at t=0 the mean is returned directly,
                 otherwise noise scaled by the posterior std is added to the mean
        """
        # posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * noise) / sqrt(alpha_t)
        mean = (x_t - self.beta_t[t] * noise / self.sqrt_one_minus_alpha_t_bar[t]) / torch.sqrt(self.alpha_t[t])
        if t == 0:
            return mean
        # posterior variance: (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        variance = (1. - self.alpha_t_bar[t - 1]) * (1. - self.alpha_t[t]) / self.one_minus_alpha_t_bar[t]
        z = torch.randn_like(x_t)
        return mean + variance ** 0.5 * z

# if __name__ == '__main__':
#     x = torch.rand(2, 3, 64, 64)
#     noise = torch.rand(1, 3, 64, 64)
#     Scheduler = LinearScheduler(1000, 0.1, 0.9)
#     t = torch.randint(0, 1000, (x.shape[0],))
#     tt = torch.as_tensor(0).unsqueeze(0)
#     add_noise = Scheduler.diffusion(x, t, noise)
#     print(add_noise.shape)
#     denoise_noise = Scheduler.denoise(x, tt, noise)
#     print(denoise_noise.shape)
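As a sketch of how the two methods are used at sampling time: the loop below assumes a trained noise-prediction network (replaced here by a random placeholder so the snippet runs), and uses the typical DDPM beta range 1e-4 to 0.02 rather than the test values above.

# placeholder for the trained noise-prediction UNet; replace with the real model
model = lambda x, t: torch.randn_like(x)

scheduler = LinearScheduler(time_step=1000, beta_start=1e-4, beta_end=0.02)
x_t = torch.randn(1, 3, 64, 64)             # start from pure Gaussian noise
for step in reversed(range(1000)):
    t = torch.as_tensor(step).unsqueeze(0)  # time step as a shape-(1,) tensor
    x_t = scheduler.denoise(x_t, t, model(x_t, t))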

Problems encountered

When I first wrote the diffusion process, the code raised a dimension-mismatch error. The code looked like this:

def diffusion(self, x_0, t, noise):
    r"""
    Noising (forward) process of the diffusion model.
    :param x_0: source image x_0
    :param t: the t-th time step
    :param noise: the noise added to x_0
    :return: the noisy image x_t at time step t
    """
    # buggy version: the (N,) coefficients cannot broadcast against (N, 3, 64, 64)
    x_t = self.sqrt_alpha_t_bar[t] * x_0 + self.sqrt_one_minus_alpha_t_bar[t] * noise
    return x_t

At this point the indexed coefficient alpha_t has shape (N,) while x_0 has shape (N, 3, 64, 64). The former has nothing to line up with dimensions 1, 2, 3 of x_0, so broadcasting naturally fails; alpha_t therefore has to be expanded to (N, 1, 1, 1) to match x_0 before the elementwise multiplication can broadcast.

def diffusion(self, x_0, t, noise):
    """
    Noising (forward) process of the diffusion model.
    :param x_0: source image x_0
    :param t: the t-th time step
    :param noise: the noise added to x_0
    :return: the noisy image x_t at time step t
    """
    sqrt_alpha_t_bar = self.sqrt_alpha_t_bar[t]
    sqrt_one_minus_alpha_t_bar = self.sqrt_one_minus_alpha_t_bar[t]
    # the fix: unsqueeze at the last dimension until the shape is (N, 1, 1, 1)
    for _ in range(len(x_0.shape) - 1):
        sqrt_alpha_t_bar = sqrt_alpha_t_bar.unsqueeze(-1)
    for _ in range(len(x_0.shape) - 1):
        sqrt_one_minus_alpha_t_bar = sqrt_one_minus_alpha_t_bar.unsqueeze(-1)
    x_t = sqrt_alpha_t_bar * x_0 + sqrt_one_minus_alpha_t_bar * noise
    return x_t
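An equivalent one-line idiom (a stylistic alternative, not what the original code uses) reshapes the coefficient in a single call:

sqrt_alpha_t_bar = self.sqrt_alpha_t_bar[t].view(-1, *([1] * (x_0.dim() - 1)))  # (N,) -> (N, 1, 1, 1)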

Model-UNet

Mainly covers time_embedding, downblock, midblock, and upblock.

time_embedding

The purpose of this module is to turn the time step into a kind of positional information, so that the diffusion model can learn the relationship between the time step t and the noise. The positional encoding used is the same sine-cosine (sinusoidal) encoding as in the Transformer, following the formula below:
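$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{\,2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{\,2i/d_{model}}}\right)$$

The implementation below concatenates the sin half and the cos half instead of interleaving them, which is a common equivalent variant.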

import torch

def get_time_embedding(pos, d_model):
    """
    :param pos: tensor of time-step positions, shape (N,)
    :param d_model: embedding dimension of the positional encoding
    :return: the positional encoding for each position, shape (N, d_model)
    """
    factor = 10000 ** (
        (torch.arange(0, d_model // 2)) / (d_model // 2)
    )
    # pos -> [N,] -> [N, 1] -> [N, d_model//2]
    pos = pos.unsqueeze(-1).repeat(1, d_model // 2)
    angle = pos / factor
    # concatenate the sin half and the cos half -> [N, d_model]
    pos_code = torch.cat([torch.sin(angle), torch.cos(angle)], dim=1)
    return pos_code

# if __name__ == '__main__':
#     x = torch.randn(1, 3, 64, 64)
#     tt = torch.randint(1, 2, (x.shape[0],))
#     y = get_time_embedding(tt, 32)

downblock

Mainly split into a Resnet module and an SA (self-attention) module; overall the input is mapped [N, C1, H, W] → [N, C2, H/2, W/2] → [N, C3, H/4, W/4] → [N, C4, H/8, W/8].


Resnet module

It consists of two convolution layers, first_conv and second_conv. Both use k=3, s=1, p=1, so the convolutions only change the number of channels, not the spatial size of the image.

In addition, the time embedding t_emb is added to the output of first_conv.

A tensor [N, C1, H, W] passes through the Resnet module to produce [N, C2, H, W], where C2 = 2 * C1 (e.g. C1 = 32, C2 = 64).

# excerpt from DownBlock.__init__
def __init__(self, in_ch, out_ch, down_sample, t_emb):
    self.first_conv = nn.Sequential(
        nn.GroupNorm(num_groups=8, num_channels=in_ch),
        nn.SiLU(),
        nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    )

    self.second_conv = nn.Sequential(
        nn.GroupNorm(num_groups=8, num_channels=out_ch),
        nn.SiLU(),
        nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    )

    self.time_proj = nn.Sequential(
        nn.SiLU(),
        nn.Linear(in_features=t_emb, out_features=out_ch)
    )

    # 1x1 conv on the skip path so the residual matches out_ch
    self.res_conv = nn.Sequential(
        nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, stride=1, padding=0)
    )

def forward(self, input, t_emb):
    # ResNet: conv -> add time embedding -> conv -> residual connection
    out = self.first_conv(input)
    out = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1) + out
    out = self.second_conv(out)
    out = out + self.res_conv(input)

SA module

A tensor [N, C2, H, W] passes through the self-attention part with its shape unchanged, [N, C2, H, W]; the trailing down_conv then halves the spatial size.

# excerpt from DownBlock.__init__
def __init__(self, in_ch, out_ch, down_sample, t_emb):
    self.self_attn_norm = nn.GroupNorm(num_groups=8, num_channels=out_ch)
    self.self_attn = nn.MultiheadAttention(embed_dim=out_ch, num_heads=8, batch_first=True)
    # down_conv (defined in the full code) is assumed to be a strided conv that
    # halves H and W, e.g. nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)

def forward(self, input, t_emb):
    # SA; `out` is the output of the Resnet part above
    batch, ch, h, w = out.shape
    attn_in = out
    attn_in = attn_in.reshape(batch, ch, h * w)   # [N, C, H*W]
    attn_in = self.self_attn_norm(attn_in)
    attn_in = attn_in.transpose(1, 2)             # [N, H*W, C] for batch_first attention
    attn_out, _ = self.self_attn(attn_in, attn_in, attn_in)
    attn_out = attn_out.transpose(1, 2)
    attn_out = attn_out.reshape(batch, ch, h, w)  # back to [N, C, H, W]
    out = out + attn_out                          # residual connection
    out = self.down_conv(out)                     # downsampling: halves H and W

How to understand the use of nn.MultiheadAttention

Q, K, V must be given with shape [N, L, E] (with batch_first=True). The input from the previous step has shape [N, C2, H, W], so it must be reshaped to [N, C2, HW] and then transposed to [N, HW, C2] before it can be used as input, as checked in the snippet below.
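A minimal standalone check of this shape round-trip (the channel count 64 and spatial size 16 are arbitrary example values):

import torch
from torch import nn

attn = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 64, 16, 16)                   # [N, C, H, W]
seq = x.reshape(2, 64, 16 * 16).transpose(1, 2)  # [N, HW, C] = [N, L, E]
out, _ = attn(seq, seq, seq)                     # self-attention: Q = K = V
print(out.shape)                                 # torch.Size([2, 256, 64])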


midblock

It likewise contains a Resnet module and an SA module, mapping the input [N, C4, H/8, W/8] → [N, C3, H/8, W/8].


def __init__(self, in_ch, out_ch, t_emb, num_head=8):
    """
    :param in_ch: number of input channels
    :param out_ch: number of output channels
    :param t_emb: embedding dimension of the time positional encoding
    :param num_head: number of attention heads
    """
    super(MidBlock, self).__init__()
    self.in_ch = in_ch
    self.out_ch = out_ch
    self.t_emb = t_emb

    self.first_conv = nn.Sequential(
        nn.GroupNorm(num_groups=8, num_channels=in_ch),
        nn.SiLU(),
        nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    )

    self.second_conv = nn.Sequential(
        nn.GroupNorm(num_groups=8, num_channels=out_ch),
        nn.SiLU(),
        nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    )

    self.t_proj = nn.Sequential(
        nn.SiLU(),
        nn.Linear(in_features=t_emb, out_features=out_ch)
    )

    self.resnet_conv = nn.Sequential(
        nn.SiLU(),
        nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, stride=1, padding=0)
    )

    self.third_conv = nn.Sequential(
        nn.SiLU(),
        nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    )

    self.self_attn_norm = nn.GroupNorm(num_groups=8, num_channels=out_ch)
    self.self_attn = nn.MultiheadAttention(embed_dim=out_ch, num_heads=num_head, batch_first=True)


def forward(self, input, t):
    # Resnet
    out = self.first_conv(input)
    out = out + self.t_proj(t).unsqueeze(-1).unsqueeze(-1)
    out = self.second_conv(out)
    out = out + self.resnet_conv(input)
    # SA
    out_copy_first = out
    batch, ch, h, w = out.shape
    out = out.reshape(batch, ch, h * w)
    in_attn = self.self_attn_norm(out)
    in_attn = in_attn.transpose(1, 2)
    out_attn, _ = self.self_attn(in_attn, in_attn, in_attn)
    out = out_attn.transpose(1, 2).reshape(batch, ch, h, w)
    out = out + out_copy_first
    # Resnet (note: third_conv is applied twice, so its weights are shared between the two calls)
    out_copy_second = out
    out = self.third_conv(out)
    out = self.third_conv(out)
    out = out_copy_second + out
    return out

upblock

Mainly contains the Resnet module; overall the input is mapped [N, C3, H/8, W/8] → [N, C2, H/4, W/4] → [N, C1, H/2, W/2] → [N, 16, H, W] → [N, 3, H, W].


def __init__(self, in_ch, out_ch):
    """
    :param in_ch: number of input channels of the feature map
    :param out_ch: number of output channels of the feature map
    """
    super(UpBlock, self).__init__()
    self.in_ch = in_ch
    self.out_ch = out_ch
    # transposed conv with k=4, s=2, p=1 doubles H and W
    self.transpose_conv = nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_ch, out_channels=in_ch, kernel_size=4, stride=2, padding=1)
    )
    # after concatenating the skip connection the channel count is in_ch * 2
    self.upsample_conv = nn.Sequential(
        nn.Conv2d(in_channels=in_ch * 2, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    )

def forward(self, input, down_input):
    out = self.transpose_conv(input)
    out = torch.cat([out, down_input], dim=1)  # concat with the matching DownBlock output
    out = self.upsample_conv(out)
    return out

Test

if __name__ == '__main__':
    x = torch.randn(1, 32, 64, 64)
    t = get_time_embedding(torch.as_tensor(10), 128)
    down_1 = DownBlock(in_ch=32, out_ch=64, down_sample=True, t_emb=128)
    down_2 = DownBlock(in_ch=64, out_ch=128, down_sample=True, t_emb=128)
    mid = MidBlock(in_ch=128, out_ch=64, t_emb=128)
    up = UpBlock(in_ch=64, out_ch=32)

    down_1_ans = down_1(x, t)
    down_2_ans = down_2(down_1_ans, t)
    print("down 1: {}".format(down_1_ans.shape))
    print("down 2: {}".format(down_2_ans.shape))
    mid_ans = mid(down_2_ans, t)
    print("mid: {}".format(mid_ans.shape))
    out_ans = up(mid_ans, down_1_ans)
    print("up: {}".format(out_ans.shape))

Complete code