

Diffusion Model Code Explained in Detail


Code: denoising-diffusion-pytorch/denoising_diffusion_pytorch.py at main · lucidrains/denoising-diffusion-pytorch (github.com)

The code for both the forward and the reverse process lives in the GaussianDiffusion class.

Questions are welcome; let's discuss them together!

Common questions answered

Why self-conditioning? · Issue #94 · lucidrains/denoising-diffusion-pytorch (github.com)

"pred_x0" preforms better than "pred_noise" · Issue #58 · lucidrains/denoising-diffusion-pytorch (github.com)

What is objective=pred_x0 and how do you use it? · Issue #34 · lucidrains/denoising-diffusion-pytorch (github.com)

Conditional generation · Issue #7 · lucidrains/denoising-diffusion-pytorch (github.com)

Questions About DDPM · Issue #10 · lucidrains/denoising-diffusion-pytorch (github.com)

The difference between pred_x0, pred_v, pred_noise three objectives · Issue #153 · lucidrains/denoising-diffusion-pytorch (github.com)

Forward training process

p_losses

First comes the p_losses function, which is the core of the training process.

```python
def p_losses(self, x_start, t, noise = None):
    b, c, h, w = x_start.shape
    # first, draw random Gaussian noise
    noise = default(noise, lambda: torch.randn_like(x_start))

    # noise sample: x_t is produced from x_0 in a single closed-form step (no loop over timesteps)
    x = self.q_sample(x_start = x_start, t = t, noise = noise)

    # if doing self-conditioning, 50% of the time, predict x_start from current set of times
    # and condition with unet with that
    # this technique will slow down training by 25%, but seems to lower FID significantly
    # (self-conditioning: the x0 predicted earlier is fed back to help the current prediction)
    x_self_cond = None
    if self.self_condition and random() < 0.5:
        with torch.no_grad():
            x_self_cond = self.model_predictions(x, t).pred_x_start
            x_self_cond.detach_()

    # predict and take gradient step
    # the noised x and the self-conditioning x go into the model together; the model is a UNet
    model_out = self.model(x, t, x_self_cond)

    # the prediction target depends on the objective (three options)
    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    elif self.objective == 'pred_v':
        v = self.predict_v(x_start, t, noise)
        target = v
    else:
        raise ValueError(f'unknown objective {self.objective}')

    # compute the loss
    loss = self.loss_fn(model_out, target, reduction = 'none')
    loss = reduce(loss, 'b ... -> b (...)', 'mean')

    loss = loss * extract(self.p2_loss_weight, t, loss.shape)
    return loss.mean()
```
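The last four lines of p_losses do three things. They reduce the element-wise loss to one value per sample, weight each sample by its timestep, and average over the batch. Below is a minimal NumPy sketch of that reduction. It assumes loss_fn is mean squared error and that p2_loss_weight is a 1-D array indexed by timestep. The einops `reduce` call is replaced by a reshape plus mean, and the function name reduce_loss_np is hypothetical.

```python
import numpy as np

def reduce_loss_np(model_out, target, t, p2_loss_weight):
    # element-wise squared error, like loss_fn(..., reduction='none') with MSE (assumed)
    loss = (model_out - target) ** 2
    # 'b ... -> b (...)' with 'mean': flatten all non-batch axes and average them
    loss = loss.reshape(loss.shape[0], -1).mean(axis=1)
    # one weight per sample, picked by its timestep (what extract does here)
    loss = loss * p2_loss_weight[t]
    # average over the batch
    return loss.mean()

rng = np.random.default_rng(0)
out = rng.standard_normal((4, 3, 8, 8))
tgt = rng.standard_normal((4, 3, 8, 8))
t = rng.integers(0, 1000, size=4)
uniform_w = np.ones(1000)  # hypothetical: every timestep weighted equally
loss = reduce_loss_np(out, tgt, t, uniform_w)
```

With all weights equal to 1, the result is just the ordinary MSE over the whole batch, because every sample has the same number of elements.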

Next, let's analyze the extract function used above. It is implemented as follows:

```python
def extract(a, t, x_shape):
    # Extract some coefficients at specified timesteps,
    # then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    b, *_ = t.shape
    # uses torch.gather: one coefficient per batch element
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
```
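extract is mostly about shapes. `a.gather(-1, t)` picks one coefficient per batch element. The reshape to (b, 1, 1, ...) then lets that coefficient broadcast over the channel and spatial dimensions.

The sketch below is a NumPy analogue under these assumptions:

- extract_np is a hypothetical name.
- For a 1-D `a`, `np.take_along_axis` plays the role of `gather`.

```python
import numpy as np

def extract_np(a, t, x_shape):
    b = t.shape[0]
    # pick a[t[i]] for each batch element i (same as a[t] for 1-D a)
    out = np.take_along_axis(a, t, axis=-1)
    # reshape to (b, 1, 1, ...) so it broadcasts against an array of shape x_shape
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

coeffs = np.linspace(0.1, 1.0, 10)   # a toy schedule with 10 timesteps
t = np.array([0, 4, 9])              # one timestep per image in the batch
x = np.ones((3, 2, 4, 4))            # batch of 3 images, 2 channels, 4x4
c = extract_np(coeffs, t, x.shape)
print(c.shape)        # (3, 1, 1, 1)
print((c * x).shape)  # (3, 2, 4, 4)
```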

q_sample

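Next come the other functions used inside p_losses. The first is q_sample, which adds noise to the input. It corresponds to this formula in the paper:

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, z, \qquad z \sim \mathcal{N}(0, I)$$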

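Here self.sqrt_alphas_cumprod is the square root of $\bar{\alpha}_t$, the cumulative product of the $\alpha$ values. self.sqrt_one_minus_alphas_cumprod is the square root of $1-\bar{\alpha}_t$. x_start corresponds to $x_0$, and noise corresponds to $z$.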

```python
def q_sample(self, x_start, t, noise = None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    return (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
    )
```
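Here is a minimal NumPy sketch of q_sample. It assumes the DDPM linear β schedule, running from 1e-4 to 0.02 over T = 1000 steps; q_sample_np is a hypothetical name. The code precomputes the two coefficient buffers and then noises a batch of two images at t = 0 and t = T − 1.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)             # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)       # \bar{alpha}_t
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)

def q_sample_np(x_start, t, noise):
    # per-sample coefficients, reshaped to broadcast over (C, H, W)
    shape = (len(t),) + (1,) * (x_start.ndim - 1)
    a = sqrt_alphas_cumprod[t].reshape(shape)
    s = sqrt_one_minus_alphas_cumprod[t].reshape(shape)
    return a * x_start + s * noise

rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(2, 3, 8, 8))     # images already scaled to [-1, 1]
eps = rng.standard_normal(x0.shape)
xt = q_sample_np(x0, np.array([0, T - 1]), eps)
```

The results behave as expected:

- At t = 0 the output is almost the clean image.
- At t = T − 1 the output is almost pure noise, since $\sqrt{\bar\alpha_{T-1}}$ is below 0.01 for this schedule.
- The two squared coefficients always sum to 1.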

model_predictions

Next is the model_predictions function, implemented as follows:

```python
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False):
    # feed the input into the UNet to get its output
    model_output = self.model(x, t, x_self_cond)
    # if clip_x_start is True, clamp the predicted x_start to [-1, 1] (the range of the
    # normalized data); otherwise leave it unchanged
    maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

    if self.objective == 'pred_noise':
        pred_noise = model_output
        x_start = self.predict_start_from_noise(x, t, pred_noise)
        x_start = maybe_clip(x_start)

    elif self.objective == 'pred_x0':
        x_start = model_output
        x_start = maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, t, x_start)

    elif self.objective == 'pred_v':
        v = model_output
        x_start = self.predict_start_from_v(x, t, v)
        x_start = maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, t, x_start)

    # return the predicted noise and the predicted x_start
    # (ModelPrediction is a namedtuple with fields pred_noise and pred_x_start)
    return ModelPrediction(pred_noise, x_start)
```

The objectives

One tricky part of model_predictions is self.objective, which can take three values:

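  • pred_noise: predict the noise. The UNet outputs the noise.
  • pred_x0: predict the original x ($x_0$). The UNet outputs the denoised image.
  • pred_v: predict the "velocity" v, introduced in Progressive Distillation for Fast Sampling of Diffusion Models (Salimans & Ho, 2022). The original x is recovered from v, and the noise is then computed from that x.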

(Figure illustrating the three objectives; the image was not preserved.)

The three objectives above rely on several conversion functions, described below.

(1) predict_start_from_noise: recovers the original x, i.e., the denoised image, from $x_t$ and the noise.

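Here self.sqrt_recip_alphas_cumprod and self.sqrt_recipm1_alphas_cumprod come from solving the q_sample formula $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, z$ for $x_0$. They are, respectively:

$$\sqrt{\frac{1}{\bar{\alpha}_t}} \qquad \text{and} \qquad \sqrt{\frac{1}{\bar{\alpha}_t} - 1}$$

Formula source: DDPM.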

```python
def predict_start_from_noise(self, x_t, t, noise):
    return (
        extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
        extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
    )
```

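It corresponds to the following formula in the paper:

$$x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\, x_t - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\; z$$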
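As a quick numerical check, predict_start_from_noise exactly undoes q_sample when it is given the true noise. The NumPy sketch below assumes the same linear β schedule as above; predict_start_from_noise_np is a hypothetical port of the method.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)                     # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)

def predict_start_from_noise_np(x_t, t, noise):
    shape = (len(t),) + (1,) * (x_t.ndim - 1)          # like extract()
    return (sqrt_recip_alphas_cumprod[t].reshape(shape) * x_t
            - sqrt_recipm1_alphas_cumprod[t].reshape(shape) * noise)

rng = np.random.default_rng(1)
x0 = rng.uniform(-1, 1, size=(3, 3, 4, 4))
eps = rng.standard_normal(x0.shape)
t = np.array([5, 300, 900])
ac = alphas_cumprod[t].reshape(-1, 1, 1, 1)
x_t = np.sqrt(ac) * x0 + np.sqrt(1 - ac) * eps         # forward process (q_sample)
x0_rec = predict_start_from_noise_np(x_t, t, eps)      # should equal x0
```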

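(2) predict_noise_from_start: recovers the noise from $x_t$ and the (predicted) image $x_0$. It is the inverse of the relation above rather than an operation that adds noise.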

```python
def predict_noise_from_start(self, x_t, t, x0):
    return (
        (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
        extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    )
```

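It corresponds to the following formula:

$$z = \frac{\frac{1}{\sqrt{\bar{\alpha}_t}}\, x_t - x_0}{\sqrt{\frac{1}{\bar{\alpha}_t} - 1}}$$

Note that it is derived in reverse: it comes from rearranging the predict_start_from_noise formula to solve for $z$.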
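Spelled out, the rearrangement goes as follows. The notation is the same as above: $\bar{\alpha}_t$ is the cumulative product of the $\alpha$ values and $z$ is the noise.

$$
\begin{aligned}
x_0 &= \frac{1}{\sqrt{\bar{\alpha}_t}}\, x_t - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\; z \\
\sqrt{\frac{1}{\bar{\alpha}_t} - 1}\; z &= \frac{1}{\sqrt{\bar{\alpha}_t}}\, x_t - x_0 \\
z &= \frac{\frac{1}{\sqrt{\bar{\alpha}_t}}\, x_t - x_0}{\sqrt{\frac{1}{\bar{\alpha}_t} - 1}}
   = \frac{x_t - \sqrt{\bar{\alpha}_t}\, x_0}{\sqrt{1-\bar{\alpha}_t}}
\end{aligned}
$$

The last form is simply the q_sample formula solved for $z$; to get it, multiply the numerator and denominator by $\sqrt{\bar{\alpha}_t}$. The code uses the middle form, probably because sqrt_recip_alphas_cumprod and sqrt_recipm1_alphas_cumprod are already precomputed.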

(3) predict_v: computes the velocity v. p_losses uses it as the training target when the objective is pred_v.

```python
def predict_v(self, x_start, t, noise):
    return (
        extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
    )
```

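It corresponds to this formula in the paper:

$$v = \alpha_t \epsilon - \sigma_t x_0$$

In that paper's notation, $\alpha_t = \sqrt{\bar{\alpha}_t}$ (sqrt_alphas_cumprod), $\sigma_t = \sqrt{1-\bar{\alpha}_t}$ (sqrt_one_minus_alphas_cumprod), and $\epsilon$ is the noise.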
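One way to see why v is a natural target is through geometry. Since $\alpha_t^2 + \sigma_t^2 = 1$, the pair $(x_t, v)$ is a rotation of $(x_0, \epsilon)$, with $\alpha_t = \cos\phi$ and $\sigma_t = \sin\phi$ for some angle $\phi$.

The NumPy sketch below checks this numerically at a single timestep, assuming the same linear β schedule as above. Both directions of the rotation are verified, and the rotation preserves the sum of squares element-wise.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)          # assumed linear schedule
ac = np.cumprod(1.0 - betas)                 # \bar{alpha}_t

rng = np.random.default_rng(3)
x0 = rng.uniform(-1, 1, size=(5,))
eps = rng.standard_normal(5)
t = 400
a, s = np.sqrt(ac[t]), np.sqrt(1 - ac[t])    # alpha_t, sigma_t; a**2 + s**2 == 1

x_t = a * x0 + s * eps    # q_sample
v = a * eps - s * x0      # predict_v

# rotating back recovers both x0 (as in predict_start_from_v) and the noise
x0_back = a * x_t - s * v
eps_back = s * x_t + a * v
```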

(4) predict_start_from_v: recovers the original x, i.e., the image, from the velocity v.

```python
def predict_start_from_v(self, x_t, t, v):
    return (
        extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
    )
```

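It corresponds to the following formula in the paper:

$$\hat{x}_0 = \alpha_t z_t - \sigma_t \hat{v}$$

where $z_t$ corresponds to $x_t$.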
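To check that the three objectives describe the same quantities, the sketch below ports the four conversion helpers to NumPy. Per-sample coefficients are reshaped the way extract does. Each branch is then fed the output that a perfect network would produce.

The sketch makes these assumptions:

- It uses the linear β schedule from earlier.
- It skips clip_x_start.
- to_prediction is a hypothetical stand-in for model_predictions.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)           # assumed linear schedule
ac = np.cumprod(1.0 - betas)                  # \bar{alpha}_t

def coef(arr, t, ndim):
    # like extract(): per-sample value reshaped to (b, 1, 1, ...)
    return arr[t].reshape((len(t),) + (1,) * (ndim - 1))

def predict_start_from_noise(x_t, t, noise):
    return coef(np.sqrt(1 / ac), t, x_t.ndim) * x_t - coef(np.sqrt(1 / ac - 1), t, x_t.ndim) * noise

def predict_noise_from_start(x_t, t, x0):
    return (coef(np.sqrt(1 / ac), t, x_t.ndim) * x_t - x0) / coef(np.sqrt(1 / ac - 1), t, x_t.ndim)

def predict_v(x0, t, noise):
    return coef(np.sqrt(ac), t, x0.ndim) * noise - coef(np.sqrt(1 - ac), t, x0.ndim) * x0

def predict_start_from_v(x_t, t, v):
    return coef(np.sqrt(ac), t, x_t.ndim) * x_t - coef(np.sqrt(1 - ac), t, x_t.ndim) * v

def to_prediction(objective, model_output, x_t, t):
    # mirrors the three branches of model_predictions (without clipping)
    if objective == 'pred_noise':
        pred_noise = model_output
        x_start = predict_start_from_noise(x_t, t, pred_noise)
    elif objective == 'pred_x0':
        x_start = model_output
        pred_noise = predict_noise_from_start(x_t, t, x_start)
    elif objective == 'pred_v':
        x_start = predict_start_from_v(x_t, t, model_output)
        pred_noise = predict_noise_from_start(x_t, t, x_start)
    else:
        raise ValueError(f'unknown objective {objective}')
    return pred_noise, x_start

rng = np.random.default_rng(2)
x0 = rng.uniform(-1, 1, size=(3, 3, 4, 4))
eps = rng.standard_normal(x0.shape)
t = np.array([10, 500, 990])
x_t = coef(np.sqrt(ac), t, 4) * x0 + coef(np.sqrt(1 - ac), t, 4) * eps

# what a perfect network would output under each objective
perfect = {'pred_noise': eps, 'pred_x0': x0, 'pred_v': predict_v(x0, t, eps)}
results = {name: to_prediction(name, out, x_t, t) for name, out in perfect.items()}
```

Every branch recovers the same noise and the same $x_0$. The objectives therefore differ in what the UNet is asked to output, not in what can be reconstructed from a correct output.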

Reverse sampling process

sample function

```python
@torch.no_grad()
def sample(self, batch_size = 16, return_all_timesteps = False):
    image_size, channels = self.image_size, self.channels
    # choose the sampling function
    sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
    # call it
    return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
```

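This function chooses a sampling function and then calls it. There are two sampling functions: p_sample_loop and ddim_sample. ddim_sample is used when is_ddim_sampling is True, i.e., when fewer sampling steps than training timesteps are configured.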

p_sample_loop function

```python
@torch.no_grad()
def p_sample_loop(self, shape, return_all_timesteps = False):
    batch, device = shape[0], self.betas.device
    # start from a random Gaussian noise image
    img = torch.randn(shape, device = device)
    imgs = [img]

    x_start = None

    # iterate over all t, from T-1 down to 0
    for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
        # reuse the previous x0 prediction if self-conditioning is enabled
        self_cond = x_start if self.self_condition else None
        # one denoising step: returns the less noisy image and the predicted x0
        img, x_start = self.p_sample(img, t, self_cond)
        imgs.append(img)

    # return only the final image, or the images from every step
    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
    # un-normalize from [-1, 1] back to [0, 1]
    ret = self.unnormalize(ret)
    return ret
```
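To see how p_sample, p_mean_variance and q_posterior fit together inside p_sample_loop, here is a minimal NumPy sketch of the whole ancestral loop. It relies on these assumptions:

- It uses a linear β schedule from 1e-4 to 0.02 over T = 1000 steps.
- The UNet is replaced by a hypothetical oracle that always returns a fixed target image.
- The names p_sample_loop_np and oracle are illustrative only.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)                       # assumed linear schedule
alphas = 1.0 - betas
ac = np.cumprod(alphas)                                  # \bar{alpha}_t
ac_prev = np.append(1.0, ac[:-1])                        # \bar{alpha}_{t-1}
coef1 = betas * np.sqrt(ac_prev) / (1.0 - ac)            # coefficient of x0
coef2 = (1.0 - ac_prev) * np.sqrt(alphas) / (1.0 - ac)   # coefficient of x_t
post_log_var = np.log(np.clip(betas * (1.0 - ac_prev) / (1.0 - ac), 1e-20, None))

def p_sample_loop_np(shape, predict_x0, rng):
    img = rng.standard_normal(shape)                     # start from pure noise
    for t in reversed(range(T)):
        # p_mean_variance: predict x0, clamp it, then take the posterior mean
        x0 = np.clip(predict_x0(img, t), -1.0, 1.0)
        mean = coef1[t] * x0 + coef2[t] * img
        # p_sample: add noise scaled by the posterior std, except at t == 0
        noise = rng.standard_normal(shape) if t > 0 else 0.0
        img = mean + np.exp(0.5 * post_log_var[t]) * noise
    return (img + 1) * 0.5                               # unnormalize_to_zero_to_one

rng = np.random.default_rng(0)
target = rng.uniform(-1, 1, size=(2, 3, 8, 8))

def oracle(x, t):
    # hypothetical perfect x0 predictor standing in for the UNet
    return target

sample = p_sample_loop_np(target.shape, oracle, rng)
```

With such an oracle, the loop returns exactly the un-normalized target. At t = 0 the posterior coefficient of $x_0$ is 1, the coefficient of $x_t$ is 0, and no noise is added.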

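This uses the normalization function self.unnormalize. There are two normalization helpers: images are scaled to [-1, 1] for the model, and samples are mapped back to [0, 1].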

```python
# normalization functions
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5
```

p_sample function

```python
@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None):
    b, *_, device = *x.shape, x.device
    batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
    # get the mean, the log variance and the predicted x0
    model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
    # draw fresh random noise
    noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
    # predicted image: img = mean + exp(0.5 * log_variance) * noise, i.e. mean + std * noise
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start
```

p_mean_variance function

p_sample calls the p_mean_variance function, which is implemented as follows:

```python
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
    # run the UNet to get its predictions
    preds = self.model_predictions(x, t, x_self_cond)
    # the predicted x0
    x_start = preds.pred_x_start

    # clamp the values of x0 to [-1, 1]
    if clip_denoised:
        x_start.clamp_(-1., 1.)

    # given x0 and x_t, compute the mean and variance of the posterior distribution
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
    return model_mean, posterior_variance, posterior_log_variance, x_start
```

q_posterior function

The q_posterior function is implemented as follows:

```python
def q_posterior(self, x_start, x_t, t):
    # posterior mean
    posterior_mean = (
        extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
        extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )
    # posterior variance
    posterior_variance = extract(self.posterior_variance, t, x_t.shape)
    # log of the variance, precomputed with the variance clamped away from 0 before the log
    posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
```

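The corresponding formulas for the mean and variance (DDPM) are:

$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\, x_t, \qquad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$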

Here self.posterior_mean_coef1 is the coefficient in front of $x_0$, and self.posterior_mean_coef2 is the coefficient in front of $x_t$. self.posterior_variance is the variance $\tilde{\beta}_t$, the part involving beta.
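The three buffers used by q_posterior can be precomputed from the β schedule with the DDPM formulas above. Below is a minimal NumPy sketch, assuming the same linear schedule (β from 1e-4 to 0.02, T = 1000). The variable names mirror the buffers in the text, but this is an illustration rather than the library's own setup code.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)            # assumed linear schedule
alphas = 1.0 - betas
ac = np.cumprod(alphas)                        # \bar{alpha}_t
ac_prev = np.append(1.0, ac[:-1])              # \bar{alpha}_{t-1}, with 1 before the first step

posterior_mean_coef1 = betas * np.sqrt(ac_prev) / (1.0 - ac)
posterior_mean_coef2 = (1.0 - ac_prev) * np.sqrt(alphas) / (1.0 - ac)
posterior_variance = betas * (1.0 - ac_prev) / (1.0 - ac)
# the variance is exactly 0 at t = 0, so clamp it before taking the log
posterior_log_variance_clipped = np.log(np.clip(posterior_variance, 1e-20, None))
```

Two properties make good sanity checks:

- The posterior variance is 0 at t = 0 and never exceeds $\beta_t$.
- If $x_t$ carried no noise ($x_t = \sqrt{\bar\alpha_t}\,x_0$), the posterior mean would be exactly $\sqrt{\bar\alpha_{t-1}}\,x_0$. Equivalently, coef1 + coef2 · $\sqrt{\bar\alpha_t}$ = $\sqrt{\bar\alpha_{t-1}}$.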

ddim_sample function

```python
@torch.no_grad()
def ddim_sample(self, shape, return_all_timesteps = False):
    batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

    times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

    img = torch.randn(shape, device = device)
    imgs = [img]

    x_start = None

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
        self_cond = x_start if self.self_condition else None
        pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True)

        imgs.append(img)

        if time_next < 0:
            img = x_start
            continue

        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]

        sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        c = (1 - alpha_next - sigma ** 2).sqrt()

        noise = torch.randn_like(img)

        img = x_start * alpha_next.sqrt() + \
              c * pred_noise + \
              sigma * noise

    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
    ret = self.unnormalize(ret)
    return ret
```

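The code above follows these formulas from the DDIM paper. The paper writes $\alpha_t$ for what is $\bar{\alpha}_t$ here. Below, $t$ is time, $s$ is time_next, $\hat{x}_0$ is x_start and $\hat{\epsilon}$ is pred_noise:

$$x_s = \sqrt{\bar{\alpha}_s}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_s-\sigma_t^2}\;\hat{\epsilon} + \sigma_t\,\epsilon$$

$$\sigma_t = \eta\sqrt{\frac{1-\bar{\alpha}_s}{1-\bar{\alpha}_t}}\sqrt{1-\frac{\bar{\alpha}_t}{\bar{\alpha}_s}}$$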
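The time-pair construction and the update step can be reproduced in NumPy to see what they do. The sketch below assumes the same linear β schedule as above. ddim_time_pairs and ddim_step are hypothetical helper names that mirror the corresponding lines of ddim_sample.

```python
import numpy as np

def ddim_time_pairs(total_timesteps, sampling_timesteps):
    # same construction as in ddim_sample, with NumPy instead of torch
    times = np.linspace(-1, total_timesteps - 1, sampling_timesteps + 1)
    times = list(reversed(times.astype(int).tolist()))
    return list(zip(times[:-1], times[1:]))

def ddim_step(x_start, pred_noise, alpha, alpha_next, eta, rng):
    # alpha and alpha_next are \bar{alpha} at time and time_next
    sigma = eta * np.sqrt((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha))
    c = np.sqrt(1 - alpha_next - sigma ** 2)
    noise = rng.standard_normal(x_start.shape)
    return x_start * np.sqrt(alpha_next) + c * pred_noise + sigma * noise

print(ddim_time_pairs(1000, 4))  # [(999, 749), (749, 499), (499, 249), (249, -1)]

T = 1000
ac = np.cumprod(1.0 - np.linspace(1e-4, 0.02, T))   # assumed linear schedule
rng = np.random.default_rng(4)
x0 = rng.uniform(-1, 1, size=(2, 3, 4, 4))
eps = rng.standard_normal(x0.shape)
time, time_next = 749, 499
# deterministic DDIM (eta = 0) with an exact x_start and pred_noise
x_next = ddim_step(x0, eps, ac[time], ac[time_next], eta=0.0, rng=rng)
```

With 4 sampling steps out of 1000, the loop visits 999 → 749 → 499 → 249 and then returns x_start. With eta = 0 the step is deterministic. Given an exact x_start and pred_noise, it lands on the point q_sample would produce at time_next with the same noise.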

The trained model (UNet)

To be continued!


Reposted from: https://blog.csdn.net/qq_41234663/article/details/128780745
Copyright belongs to the original author, lzl2040. In case of infringement, please contact us for removal.
