0


扩散模型 (Diffusion Model) 简要介绍与源码分析

扩散模型 (Diffusion Model) 简要介绍与源码分析

文章目录

前言

近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.

于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 “A dog flying in the sky” (一只狗在天空飞翔), 生成效果如下:

Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.

当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.

另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.

总览

本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.

参考文章

在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:

  • Lilian 的博客, 内容非常非常详实, 干货十足, 而且每篇文章都极其用心, 向大佬学习: What are Diffusion Models?
  • ewrfcas 的知乎, 公式推导补充了更多的细节: 由浅入深了解Diffusion Model
  • Lilian 的博客, 介绍变分自动编码器 VAE: From Autoencoder to Beta-VAE, Diffusion Model 需要从分布中随机采样样本, 该过程无法求导, 需要使用到 VAE 中介绍的重参数技巧.
  • Denoising Diffusion Probabilistic Models 论文, - 其 TF 源码位于: https://github.com/hojonathanho/diffusion, 源码介绍以该版本为主- PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 Tensorflow 版本是一致的, Stable Diffusion 参考的是 pytorch 版本的代码.

扩散模型介绍

基本原理

Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.

具体来说, 前向阶段在原始图像

  1. x
  2. 0
  3. \mathbf{x}_0
  4. x0 上逐步增加噪声, 每一步得到的图像
  5. x
  6. t
  7. \mathbf{x}_t
  8. xt 只和上一步的结果
  9. x
  10. t
  11. 1
  12. \mathbf{x}_{t - 1}
  13. xt1 相关, 直至第
  14. T
  15. T
  16. T 步的图像
  17. x
  18. T
  19. \mathbf{x}_T
  20. xT 变为纯高斯噪声. 前向阶段图示如下:

而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声

  1. x
  2. T
  3. \mathbf{x}_T
  4. xT​, 通过逐步去噪, 直至最终将原图像
  5. x
  6. 0
  7. \mathbf{x}_0
  8. x0 给恢复出来, 逆向阶段图示如下:

模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式,

前向阶段

由于前向过程中图像

  1. x
  2. t
  3. \mathbf{x}_t
  4. xt 只和上一时刻的
  5. x
  6. t
  7. 1
  8. \mathbf{x}_{t - 1}
  9. xt1 有关, 该过程可以视为马尔科夫过程, 满足:
  10. q
  11. (
  12. x
  13. 1
  14. :
  15. T
  16. x
  17. 0
  18. )
  19. =
  20. t
  21. =
  22. 1
  23. T
  24. q
  25. (
  26. x
  27. t
  28. x
  29. t
  30. 1
  31. )
  32. q
  33. (
  34. x
  35. t
  36. x
  37. t
  38. 1
  39. )
  40. =
  41. N
  42. (
  43. x
  44. t
  45. ;
  46. 1
  47. β
  48. t
  49. x
  50. t
  51. 1
  52. ,
  53. β
  54. t
  55. I
  56. )
  57. ,
  58. \begin{align} q\left(x_{1: T} \mid x_0\right) &=\prod_{t=1}^T q\left(x_t \mid x_{t-1}\right) \\ q\left(x_t \mid x_{t-1}\right) &=\mathcal{N}\left(x_t ; \sqrt{1-\beta_t} x_{t-1}, \beta_t \mathbf{I}\right), \end{align}
  59. q(x1:T​∣x0​)q(xt​∣xt1​)​=t=1Tq(xt​∣xt1​)=N(xt​;1−βt​​xt1​,βtI),​​

其中

  1. β
  2. t
  3. (
  4. 0
  5. ,
  6. 1
  7. )
  8. \beta_t\in(0, 1)
  9. βt​∈(0,1) 为高斯分布的方差超参, 并满足
  10. β
  11. 1
  12. <
  13. β
  14. 2
  15. <
  16. <
  17. β
  18. T
  19. \beta_1 < \beta_2 < \ldots < \beta_T
  20. β1​<β2​<…<βT​. 另外公式 (2) 中为何均值
  21. x
  22. t
  23. 1
  24. x_{t-1}
  25. xt1 前乘上系数
  26. 1
  27. β
  28. t
  29. x
  30. t
  31. 1
  32. \sqrt{1-\beta_t} x_{t-1}
  33. 1−βt​​xt1 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到
  34. x
  35. t
  36. x_t
  37. xt​.

重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布

  1. z
  2. N
  3. (
  4. z
  5. ;
  6. μ
  7. ,
  8. σ
  9. 2
  10. I
  11. )
  12. z \sim \mathcal{N}(z; \mu, \sigma^2\mathbf{I})
  13. zN(z;μ,σ2I) 中采样样本
  14. z
  15. z
  16. z, 可以通过引入随机变量
  17. ϵ
  18. N
  19. (
  20. 0
  21. ,
  22. I
  23. )
  24. \epsilon\sim\mathcal{N}(0, \mathbf{I})
  25. ϵ∼N(0,I), 使得
  26. z
  27. =
  28. μ
  29. +
  30. σ
  31. ϵ
  32. z = \mu + \sigma\odot\epsilon
  33. z=μ+σ⊙ϵ, 此时
  34. z
  35. z
  36. z 依旧具有随机性, 且服从高斯分布
  37. N
  38. (
  39. μ
  40. ,
  41. σ
  42. 2
  43. I
  44. )
  45. \mathcal{N}(\mu, \sigma^2\mathbf{I})
  46. N(μ,σ2I), 同时
  47. μ
  48. \mu
  49. μ
  50. σ
  51. \sigma
  52. σ (通常由网络生成) 可导.

简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样

  1. x
  2. t
  3. x_t
  4. xt 的方法, 即生成随机变量
  5. ϵ
  6. t
  7. N
  8. (
  9. 0
  10. ,
  11. I
  12. )
  13. \epsilon_t\sim\mathcal{N}(0, \mathbf{I})
  14. ϵt​∼N(0,I),

然后令

  1. α
  2. t
  3. =
  4. 1
  5. β
  6. t
  7. \alpha_t = 1 - \beta_t
  8. αt​=1−βt​, 以及
  9. α
  10. t
  11. =
  12. i
  13. =
  14. 1
  15. T
  16. α
  17. t
  18. \overline{\alpha_t} = \prod_{i=1}^{T}\alpha_t
  19. αt​​=∏i=1T​αt​, 从而可以得到:
  20. x
  21. t
  22. =
  23. 1
  24. β
  25. t
  26. x
  27. t
  28. 1
  29. +
  30. β
  31. t
  32. ϵ
  33. 1
  34. where
  35.   
  36. ϵ
  37. 1
  38. ,
  39. ϵ
  40. 2
  41. ,
  42. N
  43. (
  44. 0
  45. ,
  46. I
  47. )
  48. ,
  49.   
  50. reparameter trick
  51. ;
  52. =
  53. a
  54. t
  55. x
  56. t
  57. 1
  58. +
  59. 1
  60. α
  61. t
  62. ϵ
  63. 1
  64. =
  65. a
  66. t
  67. (
  68. a
  69. t
  70. 1
  71. x
  72. t
  73. 2
  74. +
  75. 1
  76. α
  77. t
  78. 1
  79. ϵ
  80. 2
  81. )
  82. +
  83. 1
  84. α
  85. t
  86. ϵ
  87. 1
  88. =
  89. a
  90. t
  91. a
  92. t
  93. 1
  94. x
  95. t
  96. 2
  97. +
  98. (
  99. a
  100. t
  101. (
  102. 1
  103. α
  104. t
  105. 1
  106. )
  107. ϵ
  108. 2
  109. +
  110. 1
  111. α
  112. t
  113. ϵ
  114. 1
  115. )
  116. =
  117. a
  118. t
  119. a
  120. t
  121. 1
  122. x
  123. t
  124. 2
  125. +
  126. 1
  127. α
  128. t
  129. α
  130. t
  131. 1
  132. ϵ
  133. ˉ
  134. 2
  135. where
  136. ϵ
  137. ˉ
  138. 2
  139. N
  140. (
  141. 0
  142. ,
  143. I
  144. )
  145. ;
  146. =
  147. =
  148. α
  149. ˉ
  150. t
  151. x
  152. 0
  153. +
  154. 1
  155. α
  156. ˉ
  157. t
  158. ϵ
  159. ˉ
  160. t
  161. .
  162. \begin{align} x_t &= \sqrt{1 - \beta_t} x_{t-1}+\beta_t \epsilon_1 \quad \text { where } \; \epsilon_1, \epsilon_2, \ldots \sim \mathcal{N}(0, \mathbf{I}), \; \text{reparameter trick} ; \nonumber \\ &=\sqrt{a_t} x_{t-1}+\sqrt{1-\alpha_t} \epsilon_1\nonumber \\ &=\sqrt{a_t}\left(\sqrt{a_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} \epsilon_2\right)+\sqrt{1-\alpha_t} \epsilon_1 \nonumber \\ &=\sqrt{a_t a_{t-1}} x_{t-2}+\left(\sqrt{a_t\left(1-\alpha_{t-1}\right)} \epsilon_2+\sqrt{1-\alpha_t} \epsilon_1\right) \tag{3-1} \\ &=\sqrt{a_t a_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \bar{\epsilon}_2 \quad \text { where } \quad \bar{\epsilon}_2 \sim \mathcal{N}(0, \mathbf{I}) ; \tag{3-2} \\ &=\ldots \nonumber \\ &=\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \bar{\epsilon}_t. \end{align}
  163. xt​​=1−βt​​xt1​+βt​ϵ1 where ϵ1​,ϵ2​,…∼N(0,I),reparameter trick;=at​​xt1​+1−αt​​ϵ1​=at​​(at1​​xt2​+1−αt1​​ϵ2​)+1−αt​​ϵ1​=atat1​​xt2​+(at​(1−αt1​)​ϵ2​+1−αt​​ϵ1​)=atat1​​xt2​+1−αt​αt1​​ϵˉ2 where ϵˉ2​∼N(0,I);=…=αˉt​​x0​+1−αˉt​​ϵˉt​.​(3-1)(3-2)​

其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有

  1. N
  2. (
  3. 0
  4. ,
  5. σ
  6. 1
  7. 2
  8. I
  9. )
  10. +
  11. N
  12. (
  13. 0
  14. ,
  15. σ
  16. 2
  17. 2
  18. I
  19. )
  20. N
  21. (
  22. 0
  23. ,
  24. (
  25. σ
  26. 1
  27. 2
  28. +
  29. σ
  30. 2
  31. 2
  32. )
  33. I
  34. )
  35. \mathcal{N}\left(0, \sigma_1^2\mathbf{I}\right) +\mathcal{N}\left(0,\sigma_2^2 \mathbf{I}\right)\sim\mathcal{N}\left(0, \left(\sigma_1^2 + \sigma_2^2\right)\mathbf{I}\right)
  36. N(012I)+N(022I)∼N(0,(σ12​+σ22​)I), 因此:
  37. a
  38. t
  39. (
  40. 1
  41. α
  42. t
  43. 1
  44. )
  45. ϵ
  46. 2
  47. N
  48. (
  49. 0
  50. ,
  51. a
  52. t
  53. (
  54. 1
  55. α
  56. t
  57. 1
  58. )
  59. I
  60. )
  61. 1
  62. α
  63. t
  64. ϵ
  65. 1
  66. N
  67. (
  68. 0
  69. ,
  70. (
  71. 1
  72. α
  73. t
  74. )
  75. I
  76. )
  77. a
  78. t
  79. (
  80. 1
  81. α
  82. t
  83. 1
  84. )
  85. ϵ
  86. 2
  87. +
  88. 1
  89. α
  90. t
  91. ϵ
  92. 1
  93. N
  94. (
  95. 0
  96. ,
  97. [
  98. α
  99. t
  100. (
  101. 1
  102. α
  103. t
  104. 1
  105. )
  106. +
  107. (
  108. 1
  109. α
  110. t
  111. )
  112. ]
  113. I
  114. )
  115. =
  116. N
  117. (
  118. 0
  119. ,
  120. (
  121. 1
  122. α
  123. t
  124. α
  125. t
  126. 1
  127. )
  128. I
  129. )
  130. .
  131. \begin{aligned} &\sqrt{a_t\left(1-\alpha_{t-1}\right)} \epsilon_2 \sim \mathcal{N}\left(0, a_t\left(1-\alpha_{t-1}\right) \mathbf{I}\right) \\ &\sqrt{1-\alpha_t} \epsilon_1 \sim \mathcal{N}\left(0,\left(1-\alpha_t\right) \mathbf{I}\right) \\ &\sqrt{a_t\left(1-\alpha_{t-1}\right)} \epsilon_2+\sqrt{1-\alpha_t} \epsilon_1 \sim \mathcal{N}\left(0,\left[\alpha_t\left(1-\alpha_{t-1}\right)+\left(1-\alpha_t\right)\right] \mathbf{I}\right) \\ &=\mathcal{N}\left(0,\left(1-\alpha_t \alpha_{t-1}\right) \mathbf{I}\right) . \end{aligned}
  132. at​(1−αt1​)​ϵ2​∼N(0,at​(1−αt1​)I)1−αt​​ϵ1​∼N(0,(1−αt​)I)at​(1−αt1​)​ϵ2​+1−αt​​ϵ1​∼N(0,[αt​(1−αt1​)+(1−αt​)]I)=N(0,(1−αt​αt1​)I).​

注意公式 (3-2) 中

  1. ϵ
  2. ˉ
  3. 2
  4. N
  5. (
  6. 0
  7. ,
  8. I
  9. )
  10. \bar{\epsilon}_2 \sim \mathcal{N}(0, \mathbf{I})
  11. ϵˉ2​∼N(0,I), 因此还需乘上
  12. 1
  13. α
  14. t
  15. α
  16. t
  17. 1
  18. \sqrt{1-\alpha_t \alpha_{t-1}}
  19. 1−αt​αt1​​. 从公式 (3) 可以看出
  20. q
  21. (
  22. x
  23. t
  24. x
  25. 0
  26. )
  27. =
  28. N
  29. (
  30. x
  31. t
  32. ;
  33. a
  34. ˉ
  35. t
  36. x
  37. 0
  38. ,
  39. (
  40. 1
  41. a
  42. ˉ
  43. t
  44. )
  45. I
  46. )
  47. \begin{align} q\left(x_t \mid x_0\right)=\mathcal{N}\left(x_t ; \sqrt{\bar{a}_t} x_0,\left(1-\bar{a}_t\right) \mathbf{I}\right) \end{align}
  48. q(xt​∣x0​)=N(xt​;aˉt​​x0​,(1aˉt​)I)​​

注意由于

  1. β
  2. t
  3. (
  4. 0
  5. ,
  6. 1
  7. )
  8. \beta_t\in(0, 1)
  9. βt​∈(0,1)
  10. β
  11. 1
  12. <
  13. <
  14. β
  15. T
  16. \beta_1 < \ldots < \beta_T
  17. β1​<…<βT​,
  18. α
  19. t
  20. =
  21. 1
  22. β
  23. t
  24. \alpha_t = 1 - \beta_t
  25. αt​=1−βt​, 因此
  26. α
  27. t
  28. (
  29. 0
  30. ,
  31. 1
  32. )
  33. \alpha_t\in(0, 1)
  34. αt​∈(0,1) 并且有
  35. α
  36. 1
  37. >
  38. >
  39. α
  40. T
  41. \alpha_1 > \ldots>\alpha_T
  42. α1​>…>αT​, 另外由于
  43. α
  44. ˉ
  45. t
  46. =
  47. i
  48. =
  49. 1
  50. T
  51. α
  52. t
  53. \bar{\alpha}_t=\prod_{i=1}^T\alpha_t
  54. αˉt​=∏i=1T​αt​, 因此当
  55. T
  56. T\rightarrow\infty
  57. T→∞ 时,
  58. α
  59. ˉ
  60. t
  61. 0
  62. \bar{\alpha}_t\rightarrow0
  63. αˉt​→0 以及
  64. (
  65. 1
  66. a
  67. ˉ
  68. t
  69. )
  70. 1
  71. (1-\bar{a}_t)\rightarrow 1
  72. (1aˉt​)→1, 此时
  73. x
  74. T
  75. N
  76. (
  77. 0
  78. ,
  79. I
  80. )
  81. x_T\sim\mathcal{N}(0, \mathbf{I})
  82. xT​∼N(0,I). 从这里的推导来看, 在公式 (2) 中的均值
  83. x
  84. t
  85. 1
  86. x_{t-1}
  87. xt1 前乘上系数
  88. 1
  89. β
  90. t
  91. x
  92. t
  93. 1
  94. \sqrt{1-\beta_t} x_{t-1}
  95. 1−βt​​xt1 会使得
  96. x
  97. T
  98. x_{T}
  99. xT 最后收敛到标准高斯分布.

逆向阶段

前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布

  1. q
  2. (
  3. x
  4. t
  5. 1
  6. x
  7. t
  8. )
  9. q\left(x_{t-1} \mid x_t\right)
  10. q(xt1​∣xt​), 那么通过输入高斯噪声
  11. x
  12. T
  13. N
  14. (
  15. 0
  16. ,
  17. I
  18. )
  19. x_T\sim\mathcal{N}(0, \mathbf{I})
  20. xT​∼N(0,I), 我们将生成一个真实的样本. 注意到当
  21. β
  22. t
  23. \beta_t
  24. βt 足够小时,
  25. q
  26. (
  27. x
  28. t
  29. 1
  30. x
  31. t
  32. )
  33. q\left(x_{t-1} \mid x_t\right)
  34. q(xt1​∣xt​) 也是高斯分布, 具体的证明在 ewrfcas 的知乎文章: 由浅入深了解Diffusion Model 推荐的论文中:
  1. On the theory of stochastic processes, with particular reference to applications

. 我大致看了一下, 哈哈, 没太看明白, 不过想到这个不是我关注的重点, 因此 pass. 由于我们无法直接推断

  1. q
  2. (
  3. x
  4. t
  5. 1
  6. x
  7. t
  8. )
  9. q\left(x_{t-1} \mid x_t\right)
  10. q(xt1​∣xt​), 因此我们将使用深度学习模型
  11. p
  12. θ
  13. p_{\theta}
  14. pθ​ 去拟合分布
  15. q
  16. (
  17. x
  18. t
  19. 1
  20. x
  21. t
  22. )
  23. q\left(x_{t-1} \mid x_t\right)
  24. q(xt1​∣xt​), 模型参数为
  25. θ
  26. \theta
  27. θ:
  28. p
  29. θ
  30. (
  31. x
  32. 0
  33. :
  34. T
  35. )
  36. =
  37. p
  38. (
  39. x
  40. T
  41. )
  42. t
  43. =
  44. 1
  45. T
  46. p
  47. θ
  48. (
  49. x
  50. t
  51. 1
  52. x
  53. t
  54. )
  55. p
  56. θ
  57. (
  58. x
  59. t
  60. 1
  61. x
  62. t
  63. )
  64. =
  65. N
  66. (
  67. x
  68. t
  69. 1
  70. ;
  71. μ
  72. θ
  73. (
  74. x
  75. t
  76. ,
  77. t
  78. )
  79. ,
  80. Σ
  81. θ
  82. (
  83. x
  84. t
  85. ,
  86. t
  87. )
  88. )
  89. \begin{align} p_\theta\left(x_{0: T}\right) &=p\left(x_T\right) \prod_{t=1}^T p_\theta\left(x_{t-1} \mid x_t\right) \\ p_\theta\left(x_{t-1} \mid x_t\right) &=\mathcal{N}\left(x_{t-1} ; \mu_\theta\left(x_t, t\right), \Sigma_\theta\left(x_t, t\right)\right) \end{align}
  90. pθ​(x0:T​)pθ​(xt1​∣xt​)​=p(xT​)t=1Tpθ​(xt1​∣xt​)=N(xt1​;μθ​(xt​,t),Σθ​(xt​,t))​​

注意到, 虽然我们无法直接求得

  1. q
  2. (
  3. x
  4. t
  5. 1
  6. x
  7. t
  8. )
  9. q\left(x_{t-1} \mid x_t\right)
  10. q(xt1​∣xt​) (注意这里是
  11. q
  12. q
  13. q 而不是模型
  14. p
  15. θ
  16. p_{\theta}
  17. pθ​), 但在知道
  18. x
  19. 0
  20. x_0
  21. x0 的情况下, 可以通过贝叶斯公式得到
  22. q
  23. (
  24. x
  25. t
  26. 1
  27. x
  28. t
  29. ,
  30. x
  31. 0
  32. )
  33. q\left(x_{t-1} \mid x_t, x_0\right)
  34. q(xt1​∣xt​,x0​) 为:
  35. q
  36. (
  37. x
  38. t
  39. 1
  40. x
  41. t
  42. ,
  43. x
  44. 0
  45. )
  46. =
  47. N
  48. (
  49. x
  50. t
  51. 1
  52. ;
  53. μ
  54. ~
  55. (
  56. x
  57. t
  58. ,
  59. x
  60. 0
  61. )
  62. ,
  63. β
  64. ~
  65. t
  66. I
  67. )
  68. \begin{align} q\left(x_{t-1} \mid x_t, x_0\right) &= \mathcal{N}\left(x_{t-1} ; {\color{blue}{\tilde{\mu}}(x_t, x_0)}, {\color{red}{\tilde{\beta}_t} \mathbf{I}}\right) \end{align}
  69. q(xt1​∣xt​,x0​)​=N(xt1​;μ~​(xt​,x0​),β~​tI)​​

推导过程如下:

  1. q
  2. (
  3. x
  4. t
  5. 1
  6. x
  7. t
  8. ,
  9. x
  10. 0
  11. )
  12. =
  13. q
  14. (
  15. x
  16. t
  17. x
  18. t
  19. 1
  20. ,
  21. x
  22. 0
  23. )
  24. q
  25. (
  26. x
  27. t
  28. 1
  29. x
  30. 0
  31. )
  32. q
  33. (
  34. x
  35. t
  36. x
  37. 0
  38. )
  39. exp
  40. (
  41. 1
  42. 2
  43. (
  44. (
  45. x
  46. t
  47. α
  48. t
  49. x
  50. t
  51. 1
  52. )
  53. 2
  54. β
  55. t
  56. +
  57. (
  58. x
  59. t
  60. 1
  61. α
  62. ˉ
  63. t
  64. 1
  65. x
  66. 0
  67. )
  68. 2
  69. 1
  70. α
  71. ˉ
  72. t
  73. 1
  74. (
  75. x
  76. t
  77. α
  78. ˉ
  79. t
  80. x
  81. 0
  82. )
  83. 2
  84. 1
  85. α
  86. ˉ
  87. t
  88. )
  89. )
  90. =
  91. exp
  92. (
  93. 1
  94. 2
  95. (
  96. x
  97. t
  98. 2
  99. 2
  100. α
  101. t
  102. x
  103. t
  104. x
  105. t
  106. 1
  107. +
  108. α
  109. t
  110. x
  111. t
  112. 1
  113. 2
  114. β
  115. t
  116. +
  117. x
  118. t
  119. 1
  120. 2
  121. 2
  122. α
  123. ˉ
  124. t
  125. 1
  126. x
  127. 0
  128. x
  129. t
  130. 1
  131. +
  132. α
  133. ˉ
  134. t
  135. 1
  136. x
  137. 0
  138. 2
  139. 1
  140. α
  141. ˉ
  142. t
  143. 1
  144. (
  145. x
  146. t
  147. α
  148. ˉ
  149. t
  150. x
  151. 0
  152. )
  153. 2
  154. 1
  155. α
  156. ˉ
  157. t
  158. )
  159. )
  160. =
  161. exp
  162. (
  163. 1
  164. 2
  165. (
  166. (
  167. α
  168. t
  169. β
  170. t
  171. +
  172. 1
  173. 1
  174. α
  175. ˉ
  176. t
  177. 1
  178. )
  179. x
  180. t
  181. 1
  182. 2
  183. x
  184. t
  185. 1
  186. 方差
  187. (
  188. 2
  189. α
  190. t
  191. β
  192. t
  193. x
  194. t
  195. +
  196. 2
  197. α
  198. ˉ
  199. t
  200. 1
  201. 1
  202. α
  203. ˉ
  204. t
  205. 1
  206. x
  207. 0
  208. )
  209. x
  210. t
  211. 1
  212. x
  213. t
  214. 1
  215. 均值
  216. +
  217. C
  218. (
  219. x
  220. t
  221. ,
  222. x
  223. 0
  224. )
  225. x
  226. t
  227. 1
  228. 无关
  229. )
  230. )
  231. \begin{aligned} q(x_{t-1} \vert x_t, x_0) &= q(x_t \vert x_{t-1}, x_0) \frac{ q(x_{t-1} \vert x_0) }{ q(x_t \vert x_0) } \\ &\propto \exp \Big(-\frac{1}{2} \big(\frac{(x_t - \sqrt{\alpha_t} x_{t-1})^2}{\beta_t} + \frac{(x_{t-1} - \sqrt{\bar{\alpha}_{t-1}} x_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(x_t - \sqrt{\bar{\alpha}_t} x_0)^2}{1-\bar{\alpha}_t} \big) \Big) \\ &= \exp \Big(-\frac{1}{2} \big(\frac{x_t^2 - 2\sqrt{\alpha_t} x_t \color{blue}{x_{t-1}} \color{black}{+ \alpha_t} \color{red}{x_{t-1}^2} }{\beta_t} + \frac{ \color{red}{x_{t-1}^2} \color{black}{- 2 \sqrt{\bar{\alpha}_{t-1}} x_0} \color{blue}{x_{t-1}} \color{black}{+ \bar{\alpha}_{t-1} x_0^2} }{1-\bar{\alpha}_{t-1}} - \frac{(x_t - \sqrt{\bar{\alpha}_t} x_0)^2}{1-\bar{\alpha}_t} \big) \Big) \\ &= \exp\Big( -\frac{1}{2} \big( \underbrace{\color{red}{(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}})} x_{t-1}^2}_{x_{t-1} \text { 方差 }} - \underbrace{\color{blue}{(\frac{2\sqrt{\alpha_t}}{\beta_t} x_t + \frac{2\sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_{t-1}} x_0)} x_{t-1}}_{x_{t-1} \text { 均值 }} + \underbrace{{\color{black}{ C(x_t, x_0)}}}_{\text {与 } x_{t-1} \text { 无关 }} \big) \Big) \end{aligned}
  232. q(xt1​∣xt​,x0​)​=q(xt​∣xt1​,x0​)q(xt​∣x0​)q(xt1​∣x0​)​∝exp(−21​(βt​(xt​−αt​​xt1​)2​+1−αˉt1​(xt1​−αˉt1​​x0​)2​−1−αˉt​(xt​−αˉt​​x0​)2​))=exp(−21​(βtxt2​−2αt​​xtxt1​+αtxt12​​+1−αˉt1xt12​−2αˉt1​​x0xt1​+αˉt1x02​​−1−αˉt​(xt​−αˉt​​x0​)2​))=exp(−21​(xt1 方差 t​αt​​+1−αˉt11​)xt12​​​−xt1 均值 t2αt​​​xt​+1−αˉt12αˉt1​​​x0​)xt1​​​+与 xt1 无关 C(xt​,x0​)​​))​

上面推导过程中, 通过贝叶斯公式巧妙的将逆向过程转换为前向过程, 且最终得到的概率密度函数和高斯概率密度函数的指数部分

  1. exp
  2. (
  3. (
  4. x
  5. μ
  6. )
  7. 2
  8. 2
  9. σ
  10. 2
  11. )
  12. =
  13. exp
  14. (
  15. 1
  16. 2
  17. (
  18. 1
  19. σ
  20. 2
  21. x
  22. 2
  23. 2
  24. μ
  25. σ
  26. 2
  27. x
  28. +
  29. μ
  30. 2
  31. σ
  32. 2
  33. )
  34. )
  35. \exp{\left(-\frac{\left(x - \mu\right)^2}{2\sigma^2}\right)} = \exp{\left(-\frac{1}{2}\left(\frac{1}{\sigma^2}x^2 - \frac{2\mu}{\sigma^2}x + \frac{\mu^2}{\sigma^2}\right)\right)}
  36. exp(−2σ2(x−μ)2​)=exp(−21​(σ21x2−σ22μ​x2μ2​)) 能对应, 即有:
  37. β
  38. ~
  39. t
  40. =
  41. 1
  42. /
  43. (
  44. α
  45. t
  46. β
  47. t
  48. +
  49. 1
  50. 1
  51. α
  52. ˉ
  53. t
  54. 1
  55. )
  56. =
  57. 1
  58. /
  59. (
  60. α
  61. t
  62. α
  63. ˉ
  64. t
  65. +
  66. β
  67. t
  68. β
  69. t
  70. (
  71. 1
  72. α
  73. ˉ
  74. t
  75. 1
  76. )
  77. )
  78. =
  79. 1
  80. α
  81. ˉ
  82. t
  83. 1
  84. 1
  85. α
  86. ˉ
  87. t
  88. β
  89. t
  90. μ
  91. ~
  92. t
  93. (
  94. x
  95. t
  96. ,
  97. x
  98. 0
  99. )
  100. =
  101. (
  102. α
  103. t
  104. β
  105. t
  106. x
  107. t
  108. +
  109. α
  110. ˉ
  111. t
  112. 1
  113. 1
  114. α
  115. ˉ
  116. t
  117. 1
  118. x
  119. 0
  120. )
  121. /
  122. (
  123. α
  124. t
  125. β
  126. t
  127. +
  128. 1
  129. 1
  130. α
  131. ˉ
  132. t
  133. 1
  134. )
  135. =
  136. (
  137. α
  138. t
  139. β
  140. t
  141. x
  142. t
  143. +
  144. α
  145. ˉ
  146. t
  147. 1
  148. 1
  149. α
  150. ˉ
  151. t
  152. 1
  153. x
  154. 0
  155. )
  156. 1
  157. α
  158. ˉ
  159. t
  160. 1
  161. 1
  162. α
  163. ˉ
  164. t
  165. β
  166. t
  167. =
  168. α
  169. t
  170. (
  171. 1
  172. α
  173. ˉ
  174. t
  175. 1
  176. )
  177. 1
  178. α
  179. ˉ
  180. t
  181. x
  182. t
  183. +
  184. α
  185. ˉ
  186. t
  187. 1
  188. β
  189. t
  190. 1
  191. α
  192. ˉ
  193. t
  194. x
  195. 0
  196. \begin{align} \tilde{\beta}_t &= 1/(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}) = 1/(\frac{\alpha_t - \bar{\alpha}_t + \beta_t}{\beta_t(1 - \bar{\alpha}_{t-1})}) = \color{green}{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t} \\ \tilde{\mu}_t (x_t, x_0) &= (\frac{\sqrt{\alpha_t}}{\beta_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1} }}{1 - \bar{\alpha}_{t-1}} x_0)/(\frac{\alpha_t}{\beta_t} + \frac{1}{1 - \bar{\alpha}_{t-1}}) \nonumber\\ &= (\frac{\sqrt{\alpha_t}}{\beta_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1} }}{1 - \bar{\alpha}_{t-1}} x_0) \color{green}{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t}\nonumber \\ &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} x_0\\ \end{align}
  197. β~​t​μ~​t​(xt​,x0​)​=1/(βt​αt​​+1−αˉt11​)=1/(βt​(1−αˉt1​)αt​−αˉt​+βt​​)=1−αˉt1−αˉt1​​⋅βt​=(βt​αt​​​xt​+1−αˉt1​αˉt1​​​x0​)/(βt​αt​​+1−αˉt11​)=(βt​αt​​​xt​+1−αˉt1​αˉt1​​​x0​)1−αˉt1−αˉt1​​⋅βt​=1−αˉt​αt​​(1−αˉt1​)​xt​+1−αˉt​αˉt1​​βt​​x0​​​

通过公式 (8) 和公式 (9), 我们能得到

  1. q
  2. (
  3. x
  4. t
  5. 1
  6. x
  7. t
  8. ,
  9. x
  10. 0
  11. )
  12. q\left(x_{t-1} \mid x_t, x_0\right)
  13. q(xt1​∣xt​,x0​) (见公式 (7)) 的分布. 此外由于公式 (3) 揭示的
  14. x
  15. t
  16. x_t
  17. xt
  18. x
  19. 0
  20. x_0
  21. x0 之间的关系:
  22. x
  23. t
  24. =
  25. α
  26. ˉ
  27. t
  28. x
  29. 0
  30. +
  31. 1
  32. α
  33. ˉ
  34. t
  35. ϵ
  36. ˉ
  37. t
  38. x_t =\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \bar{\epsilon}_t
  39. xt​=αˉt​​x0​+1−αˉt​​ϵˉt​, 可以得到
  40. x
  41. 0
  42. =
  43. 1
  44. α
  45. ˉ
  46. t
  47. (
  48. x
  49. t
  50. 1
  51. α
  52. ˉ
  53. t
  54. ϵ
  55. t
  56. )
  57. \begin{align} x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1 - \bar{\alpha}_t}\epsilon_t) \end{align}
  58. x0​=αˉt​​1​(xt​−1−αˉt​​ϵt​)​​

代入公式 (9) 中得到:

  1. μ
  2. ~
  3. t
  4. =
  5. α
  6. t
  7. (
  8. 1
  9. α
  10. ˉ
  11. t
  12. 1
  13. )
  14. 1
  15. α
  16. ˉ
  17. t
  18. x
  19. t
  20. +
  21. α
  22. ˉ
  23. t
  24. 1
  25. β
  26. t
  27. 1
  28. α
  29. ˉ
  30. t
  31. 1
  32. α
  33. ˉ
  34. t
  35. (
  36. x
  37. t
  38. 1
  39. α
  40. ˉ
  41. t
  42. ϵ
  43. t
  44. )
  45. =
  46. 1
  47. α
  48. t
  49. (
  50. x
  51. t
  52. 1
  53. α
  54. t
  55. 1
  56. α
  57. ˉ
  58. t
  59. ϵ
  60. t
  61. )
  62. \begin{align} \tilde{\mu}_t &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1 - \bar{\alpha}_t}\epsilon_t)\nonumber \\ &= \color{cyan}{\frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_t \Big)} \end{align}
  63. μ~​t​​=1−αˉt​αt​​(1−αˉt1​)​xt​+1−αˉt​αˉt1​​βt​​αˉt​​1​(xt​−1−αˉt​​ϵt​)=αt​​1​(xt​−1−αˉt​​1−αt​​ϵt​)​​

补充一下公式 (11) 的详细推导过程:

前面说到, 我们将使用深度学习模型

  1. p
  2. θ
  3. p_{\theta}
  4. pθ​ 去拟合逆向过程的分布
  5. q
  6. (
  7. x
  8. t
  9. 1
  10. x
  11. t
  12. )
  13. q\left(x_{t-1} \mid x_t\right)
  14. q(xt1​∣xt​), 由公式 (6)
  15. p
  16. θ
  17. (
  18. x
  19. t
  20. 1
  21. x
  22. t
  23. )
  24. =
  25. N
  26. (
  27. x
  28. t
  29. 1
  30. ;
  31. μ
  32. θ
  33. (
  34. x
  35. t
  36. ,
  37. t
  38. )
  39. ,
  40. Σ
  41. θ
  42. (
  43. x
  44. t
  45. ,
  46. t
  47. )
  48. )
  49. p_\theta\left(x_{t-1} \mid x_t\right) =\mathcal{N}\left(x_{t-1} ; \mu_\theta\left(x_t, t\right), \Sigma_\theta\left(x_t, t\right)\right)
  50. pθ​(xt1​∣xt​)=N(xt1​;μθ​(xt​,t),Σθ​(xt​,t)), 我们希望训练模型
  51. μ
  52. θ
  53. (
  54. x
  55. t
  56. ,
  57. t
  58. )
  59. \mu_\theta\left(x_t, t\right)
  60. μθ​(xt​,t) 以预估
  61. μ
  62. ~
  63. t
  64. =
  65. 1
  66. α
  67. t
  68. (
  69. x
  70. t
  71. 1
  72. α
  73. t
  74. 1
  75. α
  76. ˉ
  77. t
  78. ϵ
  79. t
  80. )
  81. \tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_t \Big)
  82. μ~​t​=αt​​1​(xt​−1−αˉt​​1−αt​​ϵt​). 由于
  83. x
  84. t
  85. x_t
  86. xt 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声
  87. ϵ
  88. t
  89. \epsilon_t
  90. ϵt​, 即令:
  91. μ
  92. θ
  93. (
  94. x
  95. t
  96. ,
  97. t
  98. )
  99. =
  100. 1
  101. α
  102. t
  103. (
  104. x
  105. t
  106. 1
  107. α
  108. t
  109. 1
  110. α
  111. ˉ
  112. t
  113. ϵ
  114. θ
  115. (
  116. x
  117. t
  118. ,
  119. t
  120. )
  121. )
  122. Thus
  123. x
  124. t
  125. 1
  126. =
  127. N
  128. (
  129. x
  130. t
  131. 1
  132. ;
  133. 1
  134. α
  135. t
  136. (
  137. x
  138. t
  139. 1
  140. α
  141. t
  142. 1
  143. α
  144. ˉ
  145. t
  146. ϵ
  147. θ
  148. (
  149. x
  150. t
  151. ,
  152. t
  153. )
  154. )
  155. ,
  156. Σ
  157. θ
  158. (
  159. x
  160. t
  161. ,
  162. t
  163. )
  164. )
  165. \begin{align} \mu_\theta(x_t, t) &= \color{cyan}{\frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \Big)} \\ \text{Thus }x_{t-1} &= \mathcal{N}(x_{t-1}; \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \Big), \boldsymbol{\Sigma}_\theta(x_t, t)) \end{align}
  166. μθ​(xt​,t)Thus xt1​​=αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t))=N(xt1​;αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t)),Σθ​(xt​,t))​​

模型训练

前面谈到, 逆向阶段让模型去预估噪声

  1. ϵ
  2. θ
  3. (
  4. x
  5. t
  6. ,
  7. t
  8. )
  9. \epsilon_\theta(x_t, t)
  10. ϵθ​(xt​,t), 那么应该如何设计 Loss 函数 ? 我们的目标是在真实数据分布下, 最大化模型预测分布的对数似然, 即优化在
  11. x
  12. 0
  13. q
  14. (
  15. x
  16. 0
  17. )
  18. x_0\sim q(x_0)
  19. x0​∼q(x0​) 下的
  20. p
  21. θ
  22. (
  23. x
  24. 0
  25. )
  26. p_\theta(x_0)
  27. pθ​(x0​) 交叉熵:
  28. L
  29. =
  30. E
  31. q
  32. (
  33. x
  34. 0
  35. )
  36. [
  37. log
  38. p
  39. θ
  40. (
  41. x
  42. 0
  43. )
  44. ]
  45. \begin{align} \mathcal{L} = \mathbb{E}_{q(x_0)}\left[-\log{p_\theta(x_0)}\right] \end{align}
  46. L=Eq(x0​)​[−logpθ​(x0​)]​​

和 变分自动编码器 VAE 类似, 使用 Variational Lower Bound 来优化:

  1. log
  2. p
  3. θ
  4. (
  5. x
  6. 0
  7. )
  8. -\log{p_\theta(x_0)}
  9. logpθ​(x0​) :
  10. log
  11. p
  12. θ
  13. (
  14. x
  15. 0
  16. )
  17. log
  18. p
  19. θ
  20. (
  21. x
  22. 0
  23. )
  24. +
  25. D
  26. K
  27. L
  28. (
  29. q
  30. (
  31. x
  32. 1
  33. :
  34. T
  35. x
  36. 0
  37. )
  38. p
  39. θ
  40. (
  41. x
  42. 1
  43. :
  44. T
  45. x
  46. 0
  47. )
  48. )
  49. ;
  50. 注: 注意KL散度非负
  51. =
  52. log
  53. p
  54. θ
  55. (
  56. x
  57. 0
  58. )
  59. +
  60. E
  61. q
  62. (
  63. x
  64. 1
  65. :
  66. T
  67. x
  68. 0
  69. )
  70. [
  71. log
  72. q
  73. (
  74. x
  75. 1
  76. :
  77. T
  78. x
  79. 0
  80. )
  81. p
  82. θ
  83. (
  84. x
  85. 0
  86. :
  87. T
  88. )
  89. /
  90. p
  91. θ
  92. (
  93. x
  94. 0
  95. )
  96. ]
  97. ;
  98.   
  99. where
  100.   
  101. p
  102. θ
  103. (
  104. x
  105. 1
  106. :
  107. T
  108. x
  109. 0
  110. )
  111. =
  112. p
  113. θ
  114. (
  115. x
  116. 0
  117. :
  118. T
  119. )
  120. p
  121. θ
  122. (
  123. x
  124. 0
  125. )
  126. =
  127. log
  128. p
  129. θ
  130. (
  131. x
  132. 0
  133. )
  134. +
  135. E
  136. q
  137. (
  138. x
  139. 1
  140. :
  141. T
  142. x
  143. 0
  144. )
  145. [
  146. log
  147. q
  148. (
  149. x
  150. 1
  151. :
  152. T
  153. x
  154. 0
  155. )
  156. p
  157. θ
  158. (
  159. x
  160. 0
  161. :
  162. T
  163. )
  164. +
  165. log
  166. p
  167. θ
  168. (
  169. x
  170. 0
  171. )
  172. q无关
  173. ]
  174. =
  175. E
  176. q
  177. (
  178. x
  179. 1
  180. :
  181. T
  182. x
  183. 0
  184. )
  185. [
  186. log
  187. q
  188. (
  189. x
  190. 1
  191. :
  192. T
  193. x
  194. 0
  195. )
  196. p
  197. θ
  198. (
  199. x
  200. 0
  201. :
  202. T
  203. )
  204. ]
  205. .
  206. \begin{align} -\log p_\theta\left(x_0\right) &\leq-\log p_\theta\left(x_0\right)+D_{K L}\left(q\left(x_{1: T} \mid x_0\right) \| p_\theta\left(x_{1: T} \mid x_0\right)\right); \quad \text{注: 注意KL散度非负}\nonumber\\ &=-\log p_\theta\left(x_0\right)+\mathbb{E}_{q\left(x_{1: T} \mid x_0\right)}\left[\log \frac{q\left(x_{1: T} \mid x_0\right)}{p_\theta\left(x_{0: T}\right) / p_\theta\left(x_0\right)}\right] ; \; \text { where } \; p_\theta\left(x_{1: T} \mid x_0\right)=\frac{p_\theta\left(x_{0: T}\right)}{p_\theta\left(x_0\right)}\nonumber\\ &=-\log p_\theta\left(x_0\right)+\mathbb{E}_{q\left(x_{1: T} \mid x_0\right)}[\log \frac{q\left(x_{1: T} \mid x_0\right)}{p_\theta\left(x_{0: T}\right)}+\underbrace{\log p_\theta\left(x_0\right)}_{\text {与q无关 }}]\nonumber\\ &=\mathbb{E}_{q\left(x_{1: T} \mid x_0\right)}\left[\log \frac{q\left(x_{1: T} \mid x_0\right)}{p_\theta\left(x_{0: T}\right)}\right] . \end{align}
  207. logpθ​(x0​)​≤−logpθ​(x0​)+DKL​(q(x1:T​∣x0​)∥pθ​(x1:T​∣x0​));注: 注意KL散度非负=−logpθ​(x0​)+Eq(x1:T​∣x0​)​[logpθ​(x0:T​)/pθ​(x0​)q(x1:T​∣x0​)​]; where pθ​(x1:T​∣x0​)=pθ​(x0​)pθ​(x0:T​)​=−logpθ​(x0​)+Eq(x1:T​∣x0​)​[logpθ​(x0:T​)q(x1:T​∣x0​)​+与q无关 logpθ​(x0​)​​]=Eq(x1:T​∣x0​)​[logpθ​(x0:T​)q(x1:T​∣x0​)​].​​

对公式 (15) 左右两边取期望

  1. E
  2. q
  3. (
  4. x
  5. 0
  6. )
  7. \mathbb{E}_{q(x_0)}
  8. Eq(x0​)​, 利用到重积分中的 Fubini 定理 可得:
  9. L
  10. V
  11. L
  12. B
  13. =
  14. E
  15. q
  16. (
  17. x
  18. 0
  19. )
  20. (
  21. E
  22. q
  23. (
  24. x
  25. 1
  26. :
  27. T
  28. x
  29. 0
  30. )
  31. [
  32. log
  33. q
  34. (
  35. x
  36. 1
  37. :
  38. T
  39. x
  40. 0
  41. )
  42. p
  43. θ
  44. (
  45. x
  46. 0
  47. :
  48. T
  49. )
  50. ]
  51. )
  52. =
  53. E
  54. q
  55. (
  56. x
  57. 0
  58. :
  59. T
  60. )
  61. [
  62. log
  63. q
  64. (
  65. x
  66. 1
  67. :
  68. T
  69. x
  70. 0
  71. )
  72. p
  73. θ
  74. (
  75. x
  76. 0
  77. :
  78. T
  79. )
  80. ]
  81. Fubini定理
  82. E
  83. q
  84. (
  85. x
  86. 0
  87. )
  88. [
  89. log
  90. p
  91. θ
  92. (
  93. x
  94. 0
  95. )
  96. ]
  97. \mathcal{L}_{V L B}=\underbrace{\mathbb{E}_{q\left(x_0\right)}\left(\mathbb{E}_{q\left(x_{1: T} \mid x_0\right)}\left[\log \frac{q\left(x_{1: T} \mid x_0\right)}{p_\theta\left(x_{0: T}\right)}\right]\right)=\mathbb{E}_{q\left(x_{0: T}\right)}\left[\log \frac{q\left(x_{1: T} \mid x_0\right)}{p_\theta\left(x_{0: T}\right)}\right]}_{\text {Fubini定理 }} \geq \mathbb{E}_{q\left(x_0\right)}\left[-\log p_\theta\left(x_0\right)\right]
  98. LVLB​=Fubini定理 Eq(x0​)​(Eq(x1:T​∣x0​)​[logpθ​(x0:T​)q(x1:T​∣x0​)​])=Eq(x0:T​)​[logpθ​(x0:T​)q(x1:T​∣x0​)​]​​≥Eq(x0​)​[−logpθ​(x0​)]

因此最小化

  1. L
  2. V
  3. L
  4. B
  5. \mathcal{L}_{V L B}
  6. LVLB 就可以优化公式 (14) 中的目标函数. 之后对
  7. L
  8. V
  9. L
  10. B
  11. \mathcal{L}_{V L B}
  12. LVLB 做进一步的推导, 这部分的详细推导见上面的参考文章, 最终的结论是:
  13. L
  14. V
  15. L
  16. B
  17. =
  18. L
  19. T
  20. +
  21. L
  22. T
  23. 1
  24. +
  25. +
  26. L
  27. 0
  28. L
  29. T
  30. =
  31. D
  32. K
  33. L
  34. (
  35. q
  36. (
  37. x
  38. T
  39. x
  40. 0
  41. )
  42. p
  43. θ
  44. (
  45. x
  46. T
  47. )
  48. )
  49. L
  50. t
  51. =
  52. D
  53. K
  54. L
  55. (
  56. q
  57. (
  58. x
  59. t
  60. x
  61. t
  62. +
  63. 1
  64. ,
  65. x
  66. 0
  67. )
  68. p
  69. θ
  70. (
  71. x
  72. t
  73. x
  74. t
  75. +
  76. 1
  77. )
  78. )
  79. ;
  80. 1
  81. t
  82. T
  83. 1
  84. L
  85. 0
  86. =
  87. log
  88. p
  89. θ
  90. (
  91. x
  92. 0
  93. x
  94. 1
  95. )
  96. \begin{align} \mathcal{L}_{V L B} &= L_T + L_{T - 1} + \ldots + L_0 \\ L_T &= D_{KL}\left(q(x_T|x_0)||p_{\theta}(x_T)\right) \\ L_t &= D_{KL}\left(q(x_t|x_{t + 1}, x_0)||p_{\theta}(x_t|x_{t+1})\right); \quad 1 \leq t \leq T - 1 \\ L_0 &= -\log{p_\theta\left(x_0|x_1\right)} \end{align}
  97. LVLBLTLtL0​​=LT​+LT1​+…+L0​=DKL​(q(xT​∣x0​)∣∣pθ​(xT​))=DKL​(q(xt​∣xt+1​,x0​)∣∣pθ​(xt​∣xt+1​));1tT1=−logpθ​(x0​∣x1​)​​

最终是优化两个高斯分布

  1. q
  2. (
  3. x
  4. t
  5. x
  6. t
  7. 1
  8. ,
  9. x
  10. 0
  11. )
  12. =
  13. N
  14. (
  15. x
  16. t
  17. 1
  18. ;
  19. μ
  20. ~
  21. (
  22. x
  23. t
  24. ,
  25. x
  26. 0
  27. )
  28. ,
  29. β
  30. ~
  31. t
  32. I
  33. )
  34. q(x_t|x_{t - 1}, x_0) = \mathcal{N}\left(x_{t-1} ; {\color{blue}{\tilde{\mu}}(x_t, x_0)}, {\color{red}{\tilde{\beta}_t} \mathbf{I}}\right)
  35. q(xt​∣xt1​,x0​)=N(xt1​;μ~​(xt​,x0​),β~​tI) (详见公式 (7))
  36. p
  37. θ
  38. (
  39. x
  40. t
  41. x
  42. t
  43. +
  44. 1
  45. )
  46. =
  47. N
  48. (
  49. x
  50. t
  51. 1
  52. ;
  53. μ
  54. θ
  55. (
  56. x
  57. t
  58. ,
  59. t
  60. )
  61. ,
  62. Σ
  63. θ
  64. )
  65. p_{\theta}(x_t|x_{t+1}) = \mathcal{N}\left(x_{t-1} ; \mu_\theta\left(x_t, t\right), \Sigma_\theta\right)
  66. pθ​(xt​∣xt+1​)=N(xt1​;μθ​(xt​,t),Σθ​) (详见公式(6), 此为模型预估的分布)之间的 KL 散度. 由于多元高斯分布的 KL 散度存在闭式解, 详见: Multivariate_normal_distributions, 从而可以得到:
  67. L
  68. t
  69. =
  70. E
  71. x
  72. 0
  73. ,
  74. ϵ
  75. [
  76. 1
  77. 2
  78. Σ
  79. θ
  80. (
  81. x
  82. t
  83. ,
  84. t
  85. )
  86. 2
  87. 2
  88. μ
  89. ~
  90. t
  91. (
  92. x
  93. t
  94. ,
  95. x
  96. 0
  97. )
  98. μ
  99. θ
  100. (
  101. x
  102. t
  103. ,
  104. t
  105. )
  106. 2
  107. ]
  108. =
  109. E
  110. x
  111. 0
  112. ,
  113. ϵ
  114. [
  115. 1
  116. 2
  117. Σ
  118. θ
  119. 2
  120. 2
  121. 1
  122. α
  123. t
  124. (
  125. x
  126. t
  127. 1
  128. α
  129. t
  130. 1
  131. α
  132. ˉ
  133. t
  134. ϵ
  135. t
  136. )
  137. 1
  138. α
  139. t
  140. (
  141. x
  142. t
  143. 1
  144. α
  145. t
  146. 1
  147. α
  148. ˉ
  149. t
  150. ϵ
  151. θ
  152. (
  153. x
  154. t
  155. ,
  156. t
  157. )
  158. )
  159. 2
  160. ]
  161. =
  162. E
  163. x
  164. 0
  165. ,
  166. ϵ
  167. [
  168. (
  169. 1
  170. α
  171. t
  172. )
  173. 2
  174. 2
  175. α
  176. t
  177. (
  178. 1
  179. α
  180. ˉ
  181. t
  182. )
  183. Σ
  184. θ
  185. 2
  186. 2
  187. ϵ
  188. t
  189. ϵ
  190. θ
  191. (
  192. x
  193. t
  194. ,
  195. t
  196. )
  197. 2
  198. ]
  199. ;
  200. 其中
  201. ϵ
  202. t
  203. 为高斯噪声
  204. ,
  205. ϵ
  206. θ
  207. 为模型学习的噪声
  208. =
  209. E
  210. x
  211. 0
  212. ,
  213. ϵ
  214. [
  215. (
  216. 1
  217. α
  218. t
  219. )
  220. 2
  221. 2
  222. α
  223. t
  224. (
  225. 1
  226. α
  227. ˉ
  228. t
  229. )
  230. Σ
  231. θ
  232. 2
  233. 2
  234. ϵ
  235. t
  236. ϵ
  237. θ
  238. (
  239. α
  240. ˉ
  241. t
  242. x
  243. 0
  244. +
  245. 1
  246. α
  247. ˉ
  248. t
  249. ϵ
  250. t
  251. ,
  252. t
  253. )
  254. 2
  255. ]
  256. \begin{align} L_t &= \mathbb{E}_{x_0, \epsilon} \Big[\frac{1}{2 \| \boldsymbol{\Sigma}_\theta(x_t, t) \|^2_2} \| \color{blue}{\tilde{\mu}_t(x_t, x_0)} - \color{green}{\mu_\theta(x_t, t)} \|^2 \Big] \\ &= \mathbb{E}_{x_0, \epsilon} \Big[\frac{1}{2 \|\boldsymbol{\Sigma}_\theta \|^2_2} \| \color{blue}{\frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_t \Big)} - \color{green}{\frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta(x_t, t) \Big)} \|^2 \Big] \\ &= \mathbb{E}_{x_0, \epsilon} \Big[\frac{ (1 - \alpha_t)^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \| \boldsymbol{\Sigma}_\theta \|^2_2} \|\epsilon_t - \epsilon_\theta(x_t, t)\|^2 \Big]; \quad \text{其中} \epsilon_t \text{为高斯噪声}, \epsilon_{\theta} \text{为模型学习的噪声} \\ &= \mathbb{E}_{x_0, \epsilon} \Big[\frac{ (1 - \alpha_t)^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \| \boldsymbol{\Sigma}_\theta \|^2_2} \|\epsilon_t - \epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1 - \bar{\alpha}_t}\epsilon_t, t)\|^2 \Big] \end{align}
  257. Lt​​=Ex0​,ϵ​[2∥Σθ​(xt​,t)∥221​∥μ~​t​(xt​,x0​)−μθ​(xt​,t)∥2]=Ex0​,ϵ​[2∥Σθ​∥221​∥αt​​1​(xt​−1−αˉt​​1−αt​​ϵt​)−αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t))∥2]=Ex0​,ϵ​[2αt​(1−αˉt​)∥Σθ​∥22​(1−αt​)2​∥ϵt​−ϵθ​(xt​,t)∥2];其中ϵt​为高斯噪声,ϵθ​为模型学习的噪声=Ex0​,ϵ​[2αt​(1−αˉt​)∥Σθ​∥22​(1−αt​)2​∥ϵt​−ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)∥2]​​

DDPM 将 Loss 简化为如下形式:

  1. L
  2. t
  3. simple
  4. =
  5. E
  6. x
  7. 0
  8. ,
  9. ϵ
  10. t
  11. [
  12. ϵ
  13. t
  14. ϵ
  15. θ
  16. (
  17. α
  18. ˉ
  19. t
  20. x
  21. 0
  22. +
  23. 1
  24. α
  25. ˉ
  26. t
  27. ϵ
  28. t
  29. ,
  30. t
  31. )
  32. 2
  33. ]
  34. \begin{align} L_t^{\text {simple }}=\mathbb{E}_{x_0, \epsilon_t}\left[\left\|\epsilon_t-\epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right)\right\|^2\right] \end{align}
  35. Ltsimple ​=Ex0​,ϵt​​[​ϵt​−ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)​2]​​

因此 Diffusion 模型的目标函数即是学习高斯噪声

  1. ϵ
  2. t
  3. \epsilon_t
  4. ϵt
  5. ϵ
  6. θ
  7. \epsilon_{\theta}
  8. ϵθ​ (来自模型输出) 之间的 MSE loss.

最终算法

最终 DDPM 的算法流程如下:

训练阶段重复如下步骤:

  • 从数据集中采样 x 0 x_0 x0​
  • 随机选取 time step t t t
  • 生成高斯噪声 ϵ t ∈ N ( 0 , I ) \epsilon_t\in\mathcal{N}(0, \mathbf{I}) ϵt​∈N(0,I)
  • 调用模型预估 ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) \epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right) ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)
  • 计算噪声之间的 MSE Loss: ∥ ϵ t − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) ∥ 2 \left|\epsilon_t-\epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right)\right|^2 ​ϵt​−ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)​2, 并利用反向传播算法训练模型.

逆向阶段采用如下步骤进行采样:

  • 从高斯分布采样 x T x_T xT​
  • 按照 T , … , 1 T, \ldots, 1 T,…,1 的顺序进行迭代: - 如果 t = 1 t = 1 t=1, 令 z = 0 \mathbf{z} = {0} z=0; 如果 t > 1 t > 1 t>1, 从高斯分布中采样 z ∼ N ( 0 , I ) \mathbf{z}\sim\mathcal{N}(0, \mathbf{I}) z∼N(0,I)- 利用公式 (12) 学习出均值 μ θ ( x t , t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta(x_t, t) = \color{cyan}{\frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \Big)} μθ​(xt​,t)=αt​​1​(xt​−1−αˉt​​1−αt​​ϵθ​(xt​,t)), 并利用公式 (8) 计算均方差 σ t = β ~ t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \sigma_t = \sqrt{\tilde{\beta}t} = \sqrt{\frac{1 - \bar{\alpha}{t-1}}{1 - \bar{\alpha}t} \cdot \beta_t} σt​=β~​t​​=1−αˉt​1−αˉt−1​​⋅βt​​- 通过重参数技巧采样 x t − 1 = μ θ ( x t , t ) + σ t z x{t - 1} = \mu_\theta(x_t, t) + \sigma_t\mathbf{z} xt−1​=μθ​(xt​,t)+σt​z
  • 经过以上过程的迭代, 最终恢复 x 0 x_0 x0​.

源码分析

DDPM 文章以及代码的相关信息如下:

本文以分析 Tensorflow 源码为主, Pytorch 版本的代码和 Tensorflow 版本的实现逻辑大体不差的, 变量名字啥的都类似, 阅读起来不会有啥门槛. Tensorlow 源码对 Diffusion 模型的实现位于 diffusion_utils_2.py, 模型本身的分析以该文件为主.

训练阶段

以 CIFAR 数据集为例.

在 run_cifar.py 中进行前向传播计算 Loss:

  • 第 6 行随机选出 t ∼ Uniform ( { 1 , … , T } ) t\sim\text{Uniform}({1, \ldots, T}) t∼Uniform({1,…,T})
  • 第 7 行 training_losses 定义在 GaussianDiffusion2 中, 计算噪声间的 MSE Loss.

进入 GaussianDiffusion2 中, 看到初始化函数中定义了诸多变量, 我在注释中使用公式的方式进行了说明:

下面进入到

  1. training_losses

函数中:

  • 第 19 行: self.model_mean_type 默认是 eps, 模型学习的是噪声, 因此 target 是第 6 行定义的 noise, 即 ϵ t \epsilon_t ϵt​
  • 第 9 行: 调用 self.q_sample 计算 x t x_t xt​, 即公式 (3) x t = α ˉ t x 0 + 1 − α ˉ t ϵ t x_t =\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t xt​=αˉt​​x0​+1−αˉt​​ϵt​
  • 第 21 行: denoise_fn 是定义在 unet.py 中的 UNet 模型, 只需知道它的输入和输出大小相同; 结合第 9 行得到的 x t x_t xt​, 得到模型预估的噪声: ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) \epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right) ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)
  • 第 23 行: 计算两个噪声之间的 MSE: ∥ ϵ t − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) ∥ 2 \left|\epsilon_t-\epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right)\right|^2 ​ϵt​−ϵθ​(αˉt​​x0​+1−αˉt​​ϵt​,t)​2, 并利用反向传播算法训练模型

上面第 9 行定义的

  1. self.q_sample

详情如下:

  • 第 13 行的 q_sample 已经介绍过, 不多说.
  • 第 2 行的 _extract 在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch, 里面的每个样本都会随机采样一个 time step t t t, 因此需要使用 tf.gather 来将 α t ˉ \bar{\alpha_t} αt​ˉ​ 之类选出来, 然后将系数 reshape 为 [B, 1, 1, ....] 的形式, 目的是为了利用 broadcasting 机制和 x t x_t xt​ 这个 Tensor 相乘.

前向的训练阶段代码实现非常简单, 下面看逆向阶段

逆向阶段

逆向阶段代码定义在 GaussianDiffusion2 中:

  • 第 5 行生成高斯噪声 x T x_T xT​, 然后对其不断去噪直至恢复原始图像
  • 第 11 行的 self.p_sample 就是公式 (6) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta\left(x_{t-1} \mid x_t\right) =\mathcal{N}\left(x_{t-1} ; \mu_\theta\left(x_t, t\right), \Sigma_\theta\left(x_t, t\right)\right) pθ​(xt−1​∣xt​)=N(xt−1​;μθ​(xt​,t),Σθ​(xt​,t)) 的过程, 使用模型来预估 μ θ ( x t , t ) \mu_\theta\left(x_t, t\right) μθ​(xt​,t) 以及 Σ θ ( x t , t ) \Sigma_\theta\left(x_t, t\right) Σθ​(xt​,t)
  • 第 12 行的 denoise_fn 在前面说过, 是定义在 unet.py 中的 UNet 模型; img_ 表示 x t x_t xt​.
  • 第 13 行的 noise_fn 则默认是 tf.random_normal, 用于生成高斯噪声.

进入

  1. p_sample

函数:

  • 第 7 行调用 self.p_mean_variance 生成 μ θ ( x t , t ) \mu_\theta\left(x_t, t\right) μθ​(xt​,t) 以及 log ⁡ ( Σ θ ( x t , t ) ) \log\left(\Sigma_\theta\left(x_t, t\right)\right) log(Σθ​(xt​,t)), 其中 Σ θ ( x t , t ) \Sigma_\theta\left(x_t, t\right) Σθ​(xt​,t) 通过计算 β ~ t \tilde{\beta}_t β~​t​ 得到.
  • 第 11 行从高斯分布中采样 z \mathbf{z} z
  • 第 18 行通过重参数技巧采样 x t − 1 = μ θ ( x t , t ) + σ t z x_{t - 1} = \mu_\theta(x_t, t) + \sigma_t\mathbf{z} xt−1​=μθ​(xt​,t)+σt​z, 其中 σ t = β ~ t \sigma_t = \sqrt{\tilde{\beta}_t} σt​=β~​t​​

进入

  1. self.p_mean_variance

函数:

  • 第 6 行调用模型 denoise_fn, 通过输入 x t x_t xt​, 输出得到噪声 ϵ t \epsilon_t ϵt​
  • 第 19 行 self.model_var_type 默认为 fixedlarge, 但我当时看 fixedsmall 比较爽, 因此 model_variancemodel_log_variance 分别为 β ~ t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \tilde{\beta}t = \frac{1 - \bar{\alpha}{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t β​t​=1−αˉt​1−αˉt−1​​⋅βt​ (见公式 8), 以及 log ⁡ β ~ t \log\tilde{\beta}_t logβ​t​
  • 第 29 行调用 self._predict_xstart_from_eps 函数, 利用公式 (10) 得到 x 0 = 1 α ˉ t ( x t − 1 − α ˉ t ϵ t ) x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1 - \bar{\alpha}_t}\epsilon_t) x0​=αˉt​​1​(xt​−1−αˉt​​ϵt​)
  • 第 30 行调用 self.q_posterior_mean_variance 通过公式 (9) 得到 μ θ ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \mu_\theta(x_t, x_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}t} x_t + \frac{\sqrt{\bar{\alpha}{t-1}}\beta_t}{1 - \bar{\alpha}_t} x_0 μθ​(xt​,x0​)=1−αˉt​αt​​(1−αˉt−1​)​xt​+1−αˉt​αˉt−1​​βt​​x0​
  1. self._predict_xstart_from_eps

函数相亲如下:

  • 该函数计算 x 0 = 1 α ˉ t ( x t − 1 − α ˉ t ϵ t ) x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1 - \bar{\alpha}_t}\epsilon_t) x0​=αˉt​​1​(xt​−1−αˉt​​ϵt​)
  1. self.q_posterior_mean_variance

函数详情如下:

  • 相关说明见注释, 另外发现对于 μ θ ( x t , x 0 ) \mu_\theta(x_t, x_0) μθ​(xt​,x0​) 的计算使用的是公式 (9) μ θ ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \mu_\theta(x_t, x_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}t} x_t + \frac{\sqrt{\bar{\alpha}{t-1}}\beta_t}{1 - \bar{\alpha}_t} x_0 μθ​(xt​,x0​)=1−αˉt​αt​​(1−αˉt−1​)​xt​+1−αˉt​αˉt−1​​βt​​x0​ 而不是进一步推导后的公式 (11) μ θ ( x t , x 0 ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) \mu_\theta(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_t \Big) μθ​(xt​,x0​)=αt​​1​(xt​−1−αˉt​​1−αt​​ϵt​).

总结

写文章真的挺累的, 好处是, 我发现写之前我以为理解了, 但写的过程中又发现有些地方理解的不对. 写完后才终于把逻辑理顺.


本文转载自: https://blog.csdn.net/Eric_1993/article/details/127455977
版权归原作者 珍妮的选择 所有, 如有侵权,请联系我们删除。

“扩散模型 (Diffusion Model) 简要介绍与源码分析”的评论:

还没有评论