1. Diffusion.py

This file implements the formula-level machinery of the DDPM diffusion model.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# extract the coefficient for the current timestep t
def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    device = t.device
    out = torch.gather(v, index=t, dim=0).float().to(device)
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
```
Here `v` is a tensor of time-varying coefficients (e.g. `alphas_bar`); for each timestep in `t`, the corresponding value is gathered out and reshaped for broadcasting.
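A quick shape check (a sketch, using a stand-in coefficient table in place of the real buffers):

```python
import torch

T = 1000
coeffs = torch.rand(T)                     # stand-in for e.g. sqrt_alphas_bar
t = torch.randint(T, size=(4,))            # one timestep per batch element
out = extract(coeffs, t, x_shape=(4, 3, 32, 32))
print(out.shape)  # torch.Size([4, 1, 1, 1]), broadcastable against the image batch
```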
Training process

GaussianDiffusionTrainer

This class implements training: learning to predict, at a given timestep t, the noise that was added.
Computing the constant coefficients:

```python
class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T
        # betas buffer: the beta schedule from step 1 to T
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)  # cumulative product of the alphas
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
```
The constructor receives the model, the β range, and the total number of timesteps T. It precomputes α = 1 − β, the cumulative product ᾱ (alphas_bar), and the square roots √ᾱ and √(1 − ᾱ) used by the forward diffusion process later on.
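For intuition, the endpoints of this schedule can be inspected directly (a sketch, using beta_1 = 1e-4, beta_T = 0.02, T = 1000, the values from the config in Main.py):

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000).double()
alphas_bar = torch.cumprod(1. - betas, dim=0)
print(alphas_bar[0].item())   # ~0.9999: x_1 is still almost the clean image
print(alphas_bar[-1].item())  # ~4e-5:  x_T is almost pure Gaussian noise
```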
Training:

Per the paper, the training objective simplifies to minimizing the gap between the predicted and the actual noise; a model that closes this gap also yields the optimal prediction of the previous state's mean.
```python
    def forward(self, x_0):
        """
        Algorithm 1.
        """
        # pick a random timestep t for each image in the batch
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        # loss comparing the noise predicted at x_t against the actual noise
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss
```
This computes x_t in closed form and the MSE between the noise predicted at timestep t and the actual noise:

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})$$

$$L_{\text{simple}} = \mathbb{E}_{t,\, x_0,\, \epsilon}\left[\left\lVert \epsilon - \epsilon_\theta(x_t, t) \right\rVert^2\right]$$
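A minimal smoke test of the trainer, using a hypothetical stand-in network named DummyEps (the real model is the U-Net from Model.py, which also takes (x, t)):

```python
import torch
import torch.nn as nn


class DummyEps(nn.Module):
    """Stand-in noise predictor, only for shape checking; it ignores t."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x, t):
        return self.conv(x)


trainer = GaussianDiffusionTrainer(DummyEps(), beta_1=1e-4, beta_T=0.02, T=1000)
x_0 = torch.randn(4, 3, 32, 32)  # a fake batch scaled to [-1, 1]
loss = trainer(x_0)              # per-element loss, shape (4, 3, 32, 32)
print(loss.shape, loss.mean().item())
```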
Sampling process
Initialization:

```python
class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # \bar{alpha}_{t-1}: shift right by one step, with \bar{alpha}_0 = 1
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
```
The constructor computes the coefficients of the reverse process: coeff1 (how much of the current state is kept), coeff2 (how much predicted noise is removed), and posterior_var (the posterior variance).

Mean (a function of $x_t$ and the noise term only):

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t)\right) = \text{coeff1} \cdot x_t - \text{coeff2} \cdot \epsilon_\theta(x_t, t)$$

Variance:

$$\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\, \beta_t$$
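These buffers can be sanity-checked against the closed forms (a sketch):

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T).double()
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

coeff1 = torch.sqrt(1. / alphas)
coeff2 = coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar)
posterior_var = betas * (1. - alphas_bar_prev) / (1. - alphas_bar)

# for large t, \bar{alpha}_t is tiny, so the posterior variance approaches beta_t
print(torch.allclose(posterior_var[-1], betas[-1], rtol=1e-2))  # True
```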
Prediction:

Using the posterior distribution and the mean/variance formulas above, the previous state is computed.

The mean:
```python
    # predict the mean of the previous state x_{t-1}
    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )
```
Obtaining the mean and the variance, with the noise term eps predicted by the model:
```python
    def p_mean_variance(self, x_t, t):
        # below: only log_variance is used in the KL computations
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)
        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var
```
Sampling:
```python
    def forward(self, x_T):
        """
        Algorithm 2.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var = self.p_mean_variance(x_t=x_t, t=t)
            # no noise is added at the final step (t == 0)
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
```
At each timestep, the trained model predicts the mean and variance of the previous state, and fresh random noise is added on top.
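End to end, sampling starts from pure Gaussian noise and denoises for T steps (a sketch; a trained U-Net would replace the DummyEps stand-in from the trainer example above):

```python
import torch

sampler = GaussianDiffusionSampler(DummyEps(), beta_1=1e-4, beta_T=0.02, T=1000)
x_T = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    x_0 = sampler(x_T)  # prints each timestep; the result is clipped to [-1, 1]
print(x_0.shape)        # torch.Size([4, 3, 32, 32])
```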
2. Model.py

This file builds the U-Net architecture that DDPM uses as its noise predictor.
```python
import math

import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


# activation function
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)
```
Time embedding

Initialization:

```python
class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        # d_model must be even: each dimension gets one sine/cosine pair
        assert d_model % 2 == 0
        super().__init__()
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)  # frequency vector
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]  # base-frequency matrix of the time embedding
        assert list(emb.shape) == [T, d_model // 2]
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()
```
Positional encoding attaches time information to every timestep index:

- d_model is checked to be even so that each dimension can carry one sine/cosine pair.
- The frequency vector is computed for generating the sine and cosine waveforms.
- Multiplying the position indices pos[:, None] by the frequency vector emb[None, :] produces the base-frequency matrix.
- Sine and cosine values are computed and stacked to form the time-embedding matrix emb.
- The time-embedding layer is built on top of this matrix.
```python
    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        emb = self.timembedding(t)
        return emb
```
initialize sets the weights of the linear layers; forward looks the timestep up in the embedding layer to fetch its time information.
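Shape check (a sketch; d_model=128 and dim=512 match the ch and tdim = ch * 4 used by the UNet below): one embedding vector comes out per timestep index.

```python
import torch

temb_layer = TimeEmbedding(T=1000, d_model=128, dim=512)
t = torch.randint(1000, (8,))
print(temb_layer(t).shape)  # torch.Size([8, 512])
```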
Upsampling and downsampling:
```python
class DownSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = self.main(x)
        return x


class UpSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        _, _, H, W = x.shape
        x = F.interpolate(
            x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x
```
In DownSample, stride=2 halves the spatial size of the feature map.
In UpSample, F.interpolate with nearest-neighbor interpolation doubles the spatial size before the convolution.
Neither layer changes the number of channels.
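Shape check (a sketch; temb is accepted for interface uniformity with ResBlock but unused by these layers):

```python
import torch

x = torch.randn(2, 64, 32, 32)
temb = torch.randn(2, 512)
print(DownSample(64)(x, temb).shape)  # torch.Size([2, 64, 16, 16])
print(UpSample(64)(x, temb).shape)    # torch.Size([2, 64, 64, 64])
```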
Self-attention:
```python
class AttnBlock(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        # scaled dot-product attention weights of shape (B, H*W, H*W)
        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        return x + h
```
This block applies self-attention to the input feature map, enhancing the model's expressiveness without changing the input/output shape: the model learns to attend to the important parts of the input and down-weight irrelevant ones.
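A quick shape check (a sketch; in_ch must be divisible by the 32 groups of the GroupNorm):

```python
import torch

x = torch.randn(2, 64, 16, 16)
print(AttnBlock(64)(x).shape)  # torch.Size([2, 64, 16, 16])
```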
Residual block:
```python
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        # inject the time embedding as a per-channel bias
        h += self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)
        h = h + self.shortcut(x)
        h = self.attn(h)
        return h
```
U-Net architecture:

```python
class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
```
Downsampling path:

```python
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
```
- Because the upsampling path must match the downsampling channel counts, chs = [ch] records every output channel count produced on the way down.
- now_ch tracks the current channel count and starts at ch.
- For each channel multiplier mult, several residual blocks are added (their skip connections counter vanishing gradients), and now_ch is updated to the blocks' output channel count.
- After every stage except the last, a DownSample layer is appended; it halves the feature-map size to reduce computation without changing the channel count.

The bookkeeping is traced concretely in the sketch below.
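A worked trace, using the configuration from the smoke test at the end of this file (ch=128, ch_mult=[1, 2, 2, 2], num_res_blocks=2):

```python
# start:     chs = [128]
# stage i=0: ResBlock 128->128 twice      -> chs += [128, 128]; DownSample -> chs += [128]
# stage i=1: ResBlock 128->256, 256->256  -> chs += [256, 256]; DownSample -> chs += [256]
# stage i=2: ResBlock 256->256 twice      -> chs += [256, 256]; DownSample -> chs += [256]
# stage i=3: ResBlock 256->256 twice      -> chs += [256, 256]  (last stage: no DownSample)
# Total: 12 entries, exactly the number of pops the upsampling path performs
# (4 stages x (num_res_blocks + 1) = 12).
```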
Middle blocks:

```python
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
```
Two middle blocks sit between the downsampling and upsampling paths and strengthen the bottleneck's capacity.
Upsampling path:

```python
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0
```
Mirroring the downsampling path, residual blocks and UpSample layers are added; each ResBlock's input channel count is chs.pop() + now_ch because the forward pass concatenates the matching skip connection.
Output layer:

```python
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1))
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)
```
- nn.GroupNorm(32, now_ch) splits the channels into 32 groups for normalization; Swish is the activation.
- The output has 3 channels (an RGB image), and the spatial size of the feature map is unchanged.
- initialize sets the weights and biases of the head and tail.
The forward pass through the U-Net:
```python
    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)

        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)

        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)

        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h
```
A smoke test instantiates the model in the main guard:
```python
if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, (batch_size,))
    y = model(x, t)
    print(y.shape)  # torch.Size([8, 3, 32, 32]): same shape as the input
```
3. Train.py

Below is the code that trains and evaluates DDPM on the CIFAR-10 dataset.

Training
```python
import os
from typing import Dict

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from tqdm import tqdm
# plus the local modules defined earlier in this post: UNet (Model.py),
# GaussianDiffusionTrainer / GaussianDiffusionSampler (Diffusion.py),
# and GradualWarmupScheduler (the repo's Scheduler.py, not shown here)


def train(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # dataset
    dataset = CIFAR10(
        root='./CIFAR10', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True,
        num_workers=4, drop_last=True, pin_memory=True)
```
Select the CPU or GPU device, load CIFAR-10 as dataset, and wrap it in a DataLoader. Note that Normalize maps pixels from [0, 1] to [-1, 1], matching the clip(-1, 1) range used by the sampler.
```python
    # model setup
    net_model = UNet(T=modelConfig["T"], ch=modelConfig["channel"],
                     ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"],
                     dropout=modelConfig["dropout"]).to(device)
    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["training_load_weight"]),
            map_location=device))
    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer, multiplier=modelConfig["multiplier"],
        warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
```
Instantiate the UNet model, the AdamW optimizer, the warmup-plus-cosine-annealing learning-rate schedule, and the diffusion trainer.
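GradualWarmupScheduler comes from the repo's Scheduler.py, which this post does not show; a minimal sketch of such a wrapper (assumed behavior: ramp the learning rate linearly up to multiplier × base over warm_epoch epochs, then defer to after_scheduler) might look like:

```python
from torch.optim.lr_scheduler import _LRScheduler


class GradualWarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
        self.multiplier = multiplier
        self.warm_epoch = warm_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.warm_epoch:
            if self.after_scheduler is not None:
                if not self.finished:
                    # hand off: the post-warmup schedule starts from the boosted LR
                    self.after_scheduler.base_lrs = [
                        lr * self.multiplier for lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [lr * self.multiplier for lr in self.base_lrs]
        # linear ramp from 1x to multiplier x of the base learning rate
        return [lr * ((self.multiplier - 1.) * self.last_epoch / self.warm_epoch + 1.)
                for lr in self.base_lrs]

    def step(self, epoch=None):
        if self.finished and self.after_scheduler is not None:
            self.after_scheduler.step()
        else:
            super().step()
```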
```python
    # start training
    for e in range(modelConfig["epoch"]):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                optimizer.zero_grad()
                x_0 = images.to(device)
                loss = trainer(x_0).sum() / 1000.
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })
        warmUpScheduler.step()
        torch.save(net_model.state_dict(), os.path.join(
            modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))
```
Training steps (for each batch within an epoch):

- Zero the gradients.
- Move the input data to the chosen device.
- Compute the loss (the per-element losses are summed and scaled down by 1000).
- Clip the gradients to prevent gradient explosion.
- Update the model parameters.

After each epoch, step the learning-rate scheduler and save the current weights.
Evaluation
```python
def eval(modelConfig: Dict):
    # load model and evaluate
    with torch.no_grad():
        device = torch.device(modelConfig["device"])
        model = UNet(T=modelConfig["T"], ch=modelConfig["channel"],
                     ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=0.)
        ckpt = torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["test_load_weight"]),
            map_location=device)
        model.load_state_dict(ckpt)
        print("model load weight done.")
        model.eval()
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
        # Sampled from standard normal distribution
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, 32, 32], device=device)
        saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
        save_image(saveNoisy, os.path.join(
            modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]),
            nrow=modelConfig["nrow"])
        sampledImgs = sampler(noisyImage)
        sampledImgs = sampledImgs * 0.5 + 0.5  # map [-1, 1] back to [0, 1]
        save_image(sampledImgs, os.path.join(
            modelConfig["sampled_dir"], modelConfig["sampledImgName"]),
            nrow=modelConfig["nrow"])
```
4. Main.py

The main function launches DDPM training or evaluation, depending on the value of the state key.
```python
from Diffusion.Train import train, eval


def main(model_config=None):
    modelConfig = {
        "state": "train",            # or "eval"
        "epoch": 200,
        "batch_size": 80,
        "T": 1000,
        "channel": 128,
        "channel_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",          # MAKE SURE YOU HAVE A GPU !!!
        "training_load_weight": None,
        "save_weight_dir": "./Checkpoints/",
        "test_load_weight": "ckpt_199_.pt",
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
    }
    if model_config is not None:
        modelConfig = model_config

    if modelConfig["state"] == "train":
        train(modelConfig)
    else:
        eval(modelConfig)


if __name__ == '__main__':
    main()
```