Diffusion Model Study Notes 1: Building DDPM in PyTorch to Generate Images with a Deep Convolutional Neural Network

Preface

I'm dead again, I'm dead again, I'm dead again!

Source Code

https://github.com/bubbliiiing/ddpm-pytorch

If you find it useful, feel free to give it a star.

Building the Network

I. What Is Diffusion?

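[Figure: the forward (noising) and reverse (denoising) processes of DDPM]
As shown in the figure above, a DDPM consists of two processes:
1. The forward (noising) process, running right to left in the figure: Gaussian noise is added step by step to a real image from the dataset until it becomes pure, structureless Gaussian noise. This process is used during training, and it follows a fixed mathematical rule.
2. The reverse (denoising) process, running left to right: a noised image is denoised step by step until a real image is recovered. This process is used when generating images. Although we speak of a noised image here, generation actually starts from randomly sampled Gaussian noise. At each denoising step, the model takes the image $X_t$, predicts the noise it contains, and uses that prediction to obtain $X_{t-1}$; repeating this restores the image.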

1. The Forward (Noising) Process

The forward noising process follows this formula:

$$x_t=\sqrt{\alpha_t}\, x_{t-1}+\sqrt{1-\alpha_t}\, z_{1}$$

Here $\sqrt{\alpha_t}$ is a hyperparameter fixed in advance by the noise schedule, and it is usually slightly below 1. Writing $\beta_t = 1-\alpha_t$, the paper increases $\beta_t$ linearly from $10^{-4}$ to $0.02$, so $\alpha_t$ decreases from 0.9999 to 0.98. $z_1 \sim \mathcal{N}(0, 1)$ is Gaussian noise. Applying the formula above recursively gives:

$$x_t=\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\, x_{t-2}+\sqrt{1-\alpha_{t-1}}\, z_2\right)+\sqrt{1-\alpha_t}\, z_1=\sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2}+\left(\sqrt{\alpha_t\left(1-\alpha_{t-1}\right)}\, z_2+\sqrt{1-\alpha_t}\, z_1\right)$$

Every noise term is drawn from a standard Gaussian, $z_1, z_2, \ldots \sim \mathcal{N}(0, 1)$, and the sum of two independent zero-mean Gaussians is again Gaussian, with the variances added:

$$\mathcal{N}\left(0, \sigma_1^2\right)+\mathcal{N}\left(0, \sigma_2^2\right) \sim \mathcal{N}\left(0, \sigma_1^2+\sigma_2^2\right)$$

The two noise terms in the bracket therefore merge into a single Gaussian with variance $\alpha_t(1-\alpha_{t-1})+(1-\alpha_t)=1-\alpha_t \alpha_{t-1}$, which gives the following formula for $x_t$:

$$x_t=\sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}}\, z_2$$

where $z_2$ now denotes the merged standard Gaussian noise.

Keep substituting in this way and the pattern becomes clear: the coefficients are simply cumulative products.
This gives the formula that goes from $x_0$ to $x_t$ directly:

$$x_t=\sqrt{\bar{\alpha}_t}\, x_0+\sqrt{1-\bar{\alpha}_t}\, z_t$$

Here $\bar{\alpha}_t=\prod_{i=1}^{t} \alpha_i$ is fixed once the noise schedule has been chosen, and $z_t \sim \mathcal{N}(0, 1)$ is again Gaussian noise. With these two formulas (the single-step update and the closed form starting from $x_0$) we can keep corrupting an image with noise.
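The closed-form jump from $x_0$ to $x_t$ can be sketched in a few lines. The version below is a minimal illustration in plain Python (no PyTorch) so that it runs anywhere; the helper names `linear_alphas`, `cumulative_products` and `q_sample`, the 1000-step length and the toy three-pixel "image" are assumptions made for this example, not code from the repository.

```python
import math
import random


def linear_alphas(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_t = 1 - beta_t for a linearly spaced beta schedule (assumed values)."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    return [1.0 - (beta_start + i * step) for i in range(num_timesteps)]


def cumulative_products(alphas):
    """alpha_bar_t = alpha_1 * alpha_2 * ... * alpha_t."""
    out, running = [], 1.0
    for a in alphas:
        running *= a
        out.append(running)
    return out


def q_sample(x0, t, alpha_bars, rng):
    """Jump from x_0 straight to x_t (t is a 0-based index into alpha_bars)."""
    signal = math.sqrt(alpha_bars[t])
    noise_scale = math.sqrt(1.0 - alpha_bars[t])
    return [signal * v + noise_scale * rng.gauss(0.0, 1.0) for v in x0]


alphas = linear_alphas()
alpha_bars = cumulative_products(alphas)
rng = random.Random(0)

x0 = [0.5, -0.25, 1.0]                        # a toy "image" of three pixel values
x_early = q_sample(x0, 0, alpha_bars, rng)    # first step: almost unchanged
x_late = q_sample(x0, 999, alpha_bars, rng)   # last step: almost pure noise
```

Because $\bar{\alpha}_t$ shrinks towards 0 as $t$ grows, the signal coefficient fades and the noise coefficient approaches 1, which is why the last step is close to pure Gaussian noise.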

2. The Reverse (Denoising) Process

The reverse process estimates the noise and, over many iterations, gradually restores the corrupted $x_t$ to $x_0$. At restoration time what we know is $x_t$, the noisy image at step $t$. Going from $x_t$ to $x_0$ in a single jump is not feasible, so we work backwards one step at a time, starting by recovering $x_{t-1}$ from $x_t$. By Bayes' rule, inferring $x_{t-1}$ from a known $x_t$ gives:

$$q\left(x_{t-1} \mid x_t, x_0\right)=q\left(x_t \mid x_{t-1}, x_0\right) \frac{q\left(x_{t-1} \mid x_0\right)}{q\left(x_t \mid x_0\right)}$$

All three factors on the right-hand side can be derived starting from $x_0$ (the third does not actually depend on $x_0$, because the forward process is a Markov chain):

$$q\left(x_{t-1} \mid x_0\right):\quad x_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\, x_0+\sqrt{1-\bar{\alpha}_{t-1}}\, z \sim \mathcal{N}\left(\sqrt{\bar{\alpha}_{t-1}}\, x_0,\ 1-\bar{\alpha}_{t-1}\right)$$

$$q\left(x_t \mid x_0\right):\quad x_t=\sqrt{\bar{\alpha}_t}\, x_0+\sqrt{1-\bar{\alpha}_t}\, z \sim \mathcal{N}\left(\sqrt{\bar{\alpha}_t}\, x_0,\ 1-\bar{\alpha}_t\right)$$

$$q\left(x_t \mid x_{t-1}, x_0\right):\quad x_t=\sqrt{\alpha_t}\, x_{t-1}+\sqrt{1-\alpha_t}\, z \sim \mathcal{N}\left(\sqrt{\alpha_t}\, x_{t-1},\ 1-\alpha_t\right)$$

Since all three factors on the right-hand side are Gaussian, $q\left(x_{t-1} \mid x_t, x_0\right)$ satisfies:

$$q\left(x_{t-1} \mid x_t, x_0\right) \propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_t-\sqrt{\alpha_t}\, x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\, x_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t}\, x_0\right)^2}{1-\bar{\alpha}_t}\right)\right)$$

Once the Gaussian densities are written out, multiplying two densities adds their exponents and dividing subtracts them, which is how the three terms above are combined.
We now simplify further. What we want is the distribution of the previous step $x_{t-1}$, so we collect terms by powers of $x_{t-1}$, with $C\left(x_t, x_0\right)$ gathering everything that does not depend on $x_{t-1}$:

$$\begin{aligned} & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(x_t-\sqrt{\alpha_t}\, x_{t-1}\right)^2}{\beta_t}+\frac{\left(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\, x_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t}\, x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\frac{x_t^2-2 \sqrt{\alpha_t}\, x_t x_{t-1}+\alpha_t x_{t-1}^2}{\beta_t}+\frac{x_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}}\, x_0 x_{t-1}+\bar{\alpha}_{t-1} x_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(x_t-\sqrt{\bar{\alpha}_t}\, x_0\right)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) x_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} x_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} x_0\right) x_{t-1}+C\left(x_t, x_0\right)\right)\right) \end{aligned}$$

A Gaussian density satisfies

$$\exp \left(-\frac{(x-\mu)^2}{2 \sigma^2}\right)=\exp \left(-\frac{1}{2}\left(\frac{1}{\sigma^2} x^2-\frac{2 \mu}{\sigma^2} x+\frac{\mu^2}{\sigma^2}\right)\right)$$

where $\sigma^2$ is the variance and $\mu$ is the mean. Matching coefficients (completing the square) therefore yields the mean and the variance.
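To make the coefficient matching explicit, here is a sketch of how the variance and the mean can be read off. It uses only $\beta_t=1-\alpha_t$ and $\bar{\alpha}_t=\alpha_t \bar{\alpha}_{t-1}$; the symbol $\tilde{\sigma}_t^2$ for the resulting variance is introduced here purely for illustration.

$$\begin{aligned} \frac{1}{\tilde{\sigma}_t^2} &=\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}=\frac{\alpha_t\left(1-\bar{\alpha}_{t-1}\right)+\beta_t}{\beta_t\left(1-\bar{\alpha}_{t-1}\right)}=\frac{1-\bar{\alpha}_t}{\beta_t\left(1-\bar{\alpha}_{t-1}\right)} \\ \tilde{\sigma}_t^2 &=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\, \beta_t \\ \frac{2 \mu}{\tilde{\sigma}_t^2} &=\frac{2 \sqrt{\alpha_t}}{\beta_t} x_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} x_0 \;\Rightarrow\; \mu=\left(\frac{\sqrt{\alpha_t}}{\beta_t} x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} x_0\right) \tilde{\sigma}_t^2 \end{aligned}$$

The variance depends only on the noise schedule, not on $x_t$ or $x_0$, while the mean is a weighted combination of $x_t$ and $x_0$.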

The mean is then:

$$\tilde{\mu}_t\left(x_t, x_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} x_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1-\bar{\alpha}_t} x_0$$

From the earlier formula $x_t=\sqrt{\bar{\alpha}_t}\, x_0+\sqrt{1-\bar{\alpha}_t}\, z_t$, we can express $x_0$ in terms of $x_t$:

$$x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t-\sqrt{1-\bar{\alpha}_t}\, z_t\right)$$

Substituting this into the mean gives

$$\tilde{\mu}_t=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} z_t\right)$$

Here $z_t$ is the noise at step $t$. Because $z_t$ cannot be obtained directly, a neural network computes it from the current $x_t$. $\epsilon_\theta\left(x_t, t\right)$ is the network's estimate of the $z_t$ above, where $\epsilon_\theta$ denotes the network. With $\beta_t=1-\alpha_t$, one reverse step is:

$$x_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta\left(x_t, t\right)\right)+\sigma_t z$$
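As a sanity check on the substitution step, the two expressions for the posterior mean (one in terms of $x_t$ and $x_0$, one in terms of $x_t$ and the noise) can be compared numerically. This is a small plain-Python sketch with arbitrary scalar values; the function names are made up for the example.

```python
import math


def posterior_mean_from_x0(x_t, x0, alpha_t, alpha_bar_t, alpha_bar_prev):
    """Mean of q(x_{t-1} | x_t, x_0), written in terms of x_t and x_0."""
    beta_t = 1.0 - alpha_t
    coef_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    coef_x0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
    return coef_xt * x_t + coef_x0 * x0


def posterior_mean_from_noise(x_t, z_t, alpha_t, alpha_bar_t):
    """The same mean after eliminating x_0, written in terms of x_t and z_t."""
    beta_t = 1.0 - alpha_t
    return (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * z_t) / math.sqrt(alpha_t)


# Toy scalar values, chosen arbitrarily for illustration.
alpha_t, alpha_bar_prev = 0.97, 0.60
alpha_bar_t = alpha_t * alpha_bar_prev
x0, z_t = 0.8, -1.3
x_t = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * z_t

mean_a = posterior_mean_from_x0(x_t, x0, alpha_t, alpha_bar_t, alpha_bar_prev)
mean_b = posterior_mean_from_noise(x_t, z_t, alpha_t, alpha_bar_t)
```

The two values agree up to floating-point rounding, which is what the algebra predicts whenever $x_t$ was produced from $x_0$ and $z_t$ by the forward formula.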

Since the true noise $\epsilon$ added during the forward process is not available during restoration, the key to DDPM is training a model $\epsilon_\theta\left(x_t, t\right)$ that estimates the noise from $x_t$ and $t$, where $\theta$ denotes the model's trainable parameters. In the term $\sigma_t z$, $z \sim \mathcal{N}(0, 1)$ is fresh Gaussian noise and $\sigma_t$ is its standard deviation at step $t$; the term accounts for the gap between the estimate and the true value. DDPM uses a U-Net as the noise-estimation model.
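The reverse update can be turned into a sampling loop. The following is a minimal plain-Python sketch under stated assumptions: `predict_noise` stands in for a trained $\epsilon_\theta$, the variance choice $\sigma_t^2=\beta_t$ is one common option, no noise is added on the very last step, and the 50-step schedule is shortened only to keep the example fast. None of these names come from the repository.

```python
import math
import random


def p_sample_step(x_t, t, predict_noise, alphas, alpha_bars, rng):
    """One reverse step x_t -> x_{t-1}; t is a 0-based index."""
    alpha_t, alpha_bar_t = alphas[t], alpha_bars[t]
    beta_t = 1.0 - alpha_t
    eps = predict_noise(x_t, t)  # stands in for epsilon_theta(x_t, t)
    mean = [
        (x - beta_t / math.sqrt(1.0 - alpha_bar_t) * e) / math.sqrt(alpha_t)
        for x, e in zip(x_t, eps)
    ]
    if t == 0:
        return mean  # no noise is added on the final step
    sigma_t = math.sqrt(beta_t)  # assumed choice: sigma_t^2 = beta_t
    return [m + sigma_t * rng.gauss(0.0, 1.0) for m in mean]


def sample(num_pixels, predict_noise, alphas, rng):
    """Start from pure Gaussian noise and apply every reverse step in turn."""
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)
    x = [rng.gauss(0.0, 1.0) for _ in range(num_pixels)]
    for t in reversed(range(len(alphas))):
        x = p_sample_step(x, t, predict_noise, alphas, alpha_bars, rng)
    return x


# Hypothetical stand-in for a trained U-Net: always predicts zero noise.
def dummy_predictor(x_t, t):
    return [0.0 for _ in x_t]


alphas = [1.0 - (1e-4 + i * (0.02 - 1e-4) / 49) for i in range(50)]
image = sample(4, dummy_predictor, alphas, random.Random(0))
```

With a real model, `predict_noise` would run the U-Net on the current tensor and step index; the loop structure stays the same.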

In essence, we are training this U-Net. Its inputs are $x_t$ and $t$, and its output is the **Gaussian noise** in $x_t$. That is, $x_t$ and $t$ are used to predict the Gaussian noise at this step. **This lets us go back from noise to a real image one step at a time.**

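II. Building the DDPM Network (the U-Net)

[Figure: a typical U-Net architecture]
The figure above shows a typical U-Net. It is only a schematic; the specific numbers in it are unrelated to this article and can be ignored. In this article the U-Net's input and output have the same shape: 3 channels (typically RGB) and identical height and width.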

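In essence, the most important job in DDPM is training the U-Net. Its inputs are $x_t$ and $t$, and its output is the **Gaussian noise** contained in $x_t$, which is what the reverse step needs in order to compute $x_{t-1}$. That is, $x_t$ and $t$ are used to predict the noise that takes us back to the previous step. **This lets us go back from noise to a real image one step at a time.**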

Suppose we want to generate a [64, 64, 3] image. At step $t$ we have a noisy image $x_t$, also of shape [64, 64, 3], which we feed into the U-Net together with $t$. The U-Net outputs a [64, 64, 3] noise estimate, which is then used to obtain $x_{t-1}$.
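A scalar step $t$ cannot be mixed into convolutional features directly, so the implementation in this article first turns it into a sinusoidal vector (half sine, half cosine). The sketch below reproduces that computation in plain Python for a single step, so the numbers can be inspected without PyTorch; the function name `timestep_embedding` is an assumption for this example.

```python
import math


def timestep_embedding(t, dim, scale=1.0):
    """Map a scalar step t to a vector of length dim: first half sin, second half cos."""
    assert dim % 2 == 0
    half_dim = dim // 2
    # Frequencies start at 1 and decay geometrically towards 1/10000.
    freqs = [math.exp(-math.log(10000) / half_dim * i) for i in range(half_dim)]
    angles = [t * scale * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


emb_0 = timestep_embedding(0, 8)
emb_10 = timestep_embedding(10, 8)
```

Nearby steps get similar vectors while distant steps differ across several frequencies, which gives the later linear layers a smooth, bounded representation of $t$ to work with.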

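The implementation is shown below. The feature-extraction modules in the code use a residual structure, which makes optimization easier: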

    import math

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def get_norm(norm, num_channels, num_groups):
        if norm == "in":
            return nn.InstanceNorm2d(num_channels, affine=True)
        elif norm == "bn":
            return nn.BatchNorm2d(num_channels)
        elif norm == "gn":
            return nn.GroupNorm(num_groups, num_channels)
        elif norm is None:
            return nn.Identity()
        else:
            raise ValueError("unknown normalization type")


    #------------------------------------------#
    #   Positional embedding of the time step.
    #   Half sin, half cos.
    #------------------------------------------#
    class PositionalEmbedding(nn.Module):
        def __init__(self, dim, scale=1.0):
            super().__init__()
            assert dim % 2 == 0
            self.dim = dim
            self.scale = scale

        def forward(self, x):
            device = x.device
            half_dim = self.dim // 2
            emb = math.log(10000) / half_dim
            emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
            # Outer product of x * self.scale and emb
            emb = torch.outer(x * self.scale, emb)
            emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
            return emb


    #------------------------------------------#
    #   Downsampling layer: a convolution with stride 2
    #------------------------------------------#
    class Downsample(nn.Module):
        def __init__(self, in_channels):
            super().__init__()
            self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)

        def forward(self, x, time_emb, y):
            if x.shape[2] % 2 == 1:
                raise ValueError("downsampling tensor height should be even")
            if x.shape[3] % 2 == 1:
                raise ValueError("downsampling tensor width should be even")
            return self.downsample(x)


    #------------------------------------------#
    #   Upsampling layer: Upsample + convolution
    #------------------------------------------#
    class Upsample(nn.Module):
        def __init__(self, in_channels):
            super().__init__()
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(in_channels, in_channels, 3, padding=1),
            )

        def forward(self, x, time_emb, y):
            return self.upsample(x)


    #------------------------------------------#
    #   Self-attention block:
    #   a global self-attention over all positions
    #------------------------------------------#
    class AttentionBlock(nn.Module):
        def __init__(self, in_channels, norm="gn", num_groups=32):
            super().__init__()
            self.in_channels = in_channels
            self.norm = get_norm(norm, in_channels, num_groups)
            self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
            self.to_out = nn.Conv2d(in_channels, in_channels, 1)

        def forward(self, x):
            b, c, h, w = x.shape
            q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

            q = q.permute(0, 2, 3, 1).view(b, h * w, c)
            k = k.view(b, c, h * w)
            v = v.permute(0, 2, 3, 1).view(b, h * w, c)

            dot_products = torch.bmm(q, k) * (c ** (-0.5))
            assert dot_products.shape == (b, h * w, h * w)

            attention = torch.softmax(dot_products, dim=-1)
            out = torch.bmm(attention, v)
            assert out.shape == (b, h * w, c)
            out = out.view(b, h, w, c).permute(0, 3, 1, 2)

            return self.to_out(out) + x


    #------------------------------------------#
    #   Residual block used for feature extraction
    #------------------------------------------#
    class ResidualBlock(nn.Module):
        def __init__(
            self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,
            norm="gn", num_groups=32, use_attention=False,
        ):
            super().__init__()

            self.activation = activation

            self.norm_1 = get_norm(norm, in_channels, num_groups)
            self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

            self.norm_2 = get_norm(norm, out_channels, num_groups)
            self.conv_2 = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
            )

            self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
            self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

            self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
            self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)

        def forward(self, x, time_emb=None, y=None):
            out = self.activation(self.norm_1(x))
            # First convolution
            out = self.conv_1(out)

            # Project time_emb with a linear layer and add it per channel
            if self.time_bias is not None:
                if time_emb is None:
                    raise ValueError("time conditioning was specified but time_emb is not passed")
                out += self.time_bias(self.activation(time_emb))[:, :, None, None]

            # Embed the class label y and add it per channel
            if self.class_bias is not None:
                if y is None:
                    raise ValueError("class conditioning was specified but y is not passed")
                out += self.class_bias(y)[:, :, None, None]

            out = self.activation(self.norm_2(out))
            # Second convolution + residual connection
            out = self.conv_2(out) + self.residual_connection(x)
            # Finally, attention
            out = self.attention(out)
            return out


    #------------------------------------------#
    #   U-Net model
    #------------------------------------------#
    class UNet(nn.Module):
        def __init__(
            self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),
            num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,
            dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
        ):
            super().__init__()
            # Activation function, usually SiLU
            self.activation = activation
            # Amount of padding applied to the input (0 disables it)
            self.initial_pad = initial_pad
            # Number of classes to distinguish
            self.num_classes = num_classes

            # MLP applied to the time-step input
            self.time_mlp = nn.Sequential(
                PositionalEmbedding(base_channels, time_emb_scale),
                nn.Linear(base_channels, time_emb_dim),
                nn.SiLU(),
                nn.Linear(time_emb_dim, time_emb_dim),
            ) if time_emb_dim is not None else None

            # First convolution on the input image
            self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

            # self.downs holds the downsampling path: ResidualBlock extracts features,
            # then Downsample halves the height and width of the feature map
            self.downs = nn.ModuleList()
            self.ups = nn.ModuleList()

            # channels records the output channel count of every module
            # now_channels is a running variable holding the current channel count
            channels = [base_channels]
            now_channels = base_channels
            for i, mult in enumerate(channel_mults):
                out_channels = base_channels * mult
                for _ in range(num_res_blocks):
                    self.downs.append(
                        ResidualBlock(
                            now_channels, out_channels, dropout,
                            time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                            norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                        )
                    )
                    now_channels = out_channels
                    channels.append(now_channels)

                if i != len(channel_mults) - 1:
                    self.downs.append(Downsample(now_channels))
                    channels.append(now_channels)

            # Middle feature-extraction module, which integrates the features
            self.mid = nn.ModuleList([
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=True,
                ),
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=False,
                ),
            ])

            # Upsampling path with feature fusion
            for i, mult in reversed(list(enumerate(channel_mults))):
                out_channels = base_channels * mult

                for _ in range(num_res_blocks + 1):
                    self.ups.append(ResidualBlock(
                        channels.pop() + now_channels, out_channels, dropout,
                        time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                        norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                    ))
                    now_channels = out_channels

                if i != 0:
                    self.ups.append(Upsample(now_channels))

            assert len(channels) == 0

            self.out_norm = get_norm(norm, base_channels, num_groups)
            self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

        def forward(self, x, time=None, y=None):
            # Optionally pad the input
            ip = self.initial_pad
            if ip != 0:
                x = F.pad(x, (ip,) * 4)

            # MLP on the time-step input
            if self.time_mlp is not None:
                if time is None:
                    raise ValueError("time conditioning was specified but time is not passed")
                time_emb = self.time_mlp(time)
            else:
                time_emb = None

            if self.num_classes is not None and y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            # First convolution on the input image
            x = self.init_conv(x)

            # skips stores the intermediate outputs of the downsampling path
            skips = [x]
            for layer in self.downs:
                x = layer(x, time_emb, y)
                skips.append(x)

            # Feature integration and extraction
            for layer in self.mid:
                x = layer(x, time_emb, y)

            # Upsampling with feature fusion
            for layer in self.ups:
                if isinstance(layer, ResidualBlock):
                    x = torch.cat([x, skips.pop()], dim=1)
                x = layer(x, time_emb, y)

            # Output head: normalization, activation and final convolution
            x = self.activation(self.out_norm(x))
            x = self.out_conv(x)

            if self.initial_pad != 0:
                return x[:, :, ip:-ip, ip:-ip]
            else:
                return x
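The least obvious part of the constructor is the `channels` list, which records one entry per skip connection on the way down and is consumed entry by entry on the way up. The sketch below replays only that bookkeeping in plain Python, without building any layers, using the default arguments of `UNet`; the function name `skip_channel_plan` is made up for this illustration.

```python
def skip_channel_plan(base_channels=128, channel_mults=(1, 2, 2, 2), num_res_blocks=2):
    """Replay the channel bookkeeping of the U-Net constructor without building layers."""
    channels = [base_channels]          # output of the first convolution
    now = base_channels
    for i, mult in enumerate(channel_mults):
        out = base_channels * mult
        for _ in range(num_res_blocks):
            now = out
            channels.append(now)        # one skip per residual block
        if i != len(channel_mults) - 1:
            channels.append(now)        # one skip per downsampling layer
    pushed = list(channels)

    up_inputs = []
    for i, mult in reversed(list(enumerate(channel_mults))):
        out = base_channels * mult
        for _ in range(num_res_blocks + 1):
            # Each up block sees the skip tensor concatenated with the current tensor.
            up_inputs.append(channels.pop() + now)
            now = out
    return pushed, up_inputs, channels


pushed, up_inputs, leftover = skip_channel_plan()
```

With the defaults, twelve skip entries are recorded and all twelve are consumed, which is exactly what the `assert len(channels) == 0` in the constructor checks. It also explains why the up path uses `num_res_blocks + 1` blocks per level: the extra block absorbs the skip saved by the first convolution or by a downsampling layer.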

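III. Training Approach

Training a diffusion model is fairly simple. First, a random t is drawn for every image in the batch, meaning that the noise at step t is the one we fit for that image. The code is:

    t = torch.randint(0, self.num_timesteps, (b,), device=device)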

Next, generate batch_size noise tensors and compute what each image looks like at step t once this noise has been applied, as follows:

    def perturb_x(self, x, t, noise):
        return (
            extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )

    def get_losses(self, x, t, y):
        # x, noise: [batch_size, 3, 64, 64]
        noise = torch.randn_like(x)
        perturbed_x = self.perturb_x(x, t, noise)

The noisy image, t, and the network are then used to predict the noise, and the predicted noise is fit to the actual noise.

    def get_losses(self, x, t, y):
        # x, noise: [batch_size, 3, 64, 64]
        noise = torch.randn_like(x)

        perturbed_x = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)
        return loss
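The three training steps (draw t, add noise, compare predicted and true noise) can be followed end to end on toy data. The sketch below is plain Python rather than PyTorch and is only an illustration: `training_loss` and the zero-predicting model are hypothetical, and the per-sample coefficient lookup is assumed to be the role that `extract` plays in the snippets above, since its definition is not shown in this article.

```python
import math
import random


def training_loss(batch, alpha_bars, predict_noise, rng):
    """L2 loss between true and predicted noise for one batch of flat toy images."""
    total, count = 0.0, 0
    for x0 in batch:
        t = rng.randrange(len(alpha_bars))          # one random step per image
        noise = [rng.gauss(0.0, 1.0) for _ in x0]
        # Per-sample coefficient lookup (assumed to be what `extract` does).
        a = math.sqrt(alpha_bars[t])
        b = math.sqrt(1.0 - alpha_bars[t])
        x_t = [a * v + b * n for v, n in zip(x0, noise)]
        predicted = predict_noise(x_t, t)
        total += sum((p - n) ** 2 for p, n in zip(predicted, noise))
        count += len(x0)
    return total / count


# Cumulative products for an assumed linear schedule with 1000 steps.
alpha_bars, running = [], 1.0
for i in range(1000):
    running *= 1.0 - (1e-4 + i * (0.02 - 1e-4) / 999)
    alpha_bars.append(running)

batch = [[0.1, 0.2, 0.3, 0.4], [-0.5, 0.0, 0.5, 1.0]]
# Hypothetical untrained model that always predicts zero noise.
loss = training_loss(batch, alpha_bars, lambda x_t, t: [0.0] * len(x_t), random.Random(0))
```

A model that always predicts zero is left with a loss equal to the mean squared noise, so any useful network has to drive the loss below that baseline.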

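Generating Images with DDPM

The overall structure of the DDPM repository is as follows:
[Figure: directory structure of the repository]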

I. Preparing the Dataset

Before training, prepare the dataset and save it in the datasets folder.

II. Processing the Dataset

Open txt_annotation.py, which by default points to the datasets folder in the repository root, and run it.
This generates train_lines.txt in the repository root.

III. Training the Model

Once the dataset has been processed, run train.py to start training.
During training, you can check the results in the results folder.


Reposted from: https://blog.csdn.net/weixin_44791964/article/details/128604816
Copyright belongs to the original author, Bubbliiiing. If there is any infringement, please contact us for removal.
