扩散模型在生成高质量图像领域具有显著优势,但其迭代去噪过程导致计算开销较大。分布匹配蒸馏(Distribution Matching Distillation,DMD)通过将多步扩散过程精简为单步生成器来解决这一问题。该方法结合分布匹配损失函数和对抗生成网络损失,实现从噪声图像到真实图像的高效映射,为快速图像生成应用提供了新的技术路径。
分布匹配机制
与传统扩散模型不同,单步生成器并不直接学习完整的数据分布,而是通过强制对齐的方式逼近目标分布。这种方法摒弃了逐步近似的过程,直接建立噪声样本到目标分布的映射关系。
在此过程中,蒸馏机制起到关键作用。预训练模型作为教师网络,提供目标分布的高精度中间表征。
DMD 技术实现流程
阶段 0:系统初始化
- 单步生成器基于预训练扩散 unet 进行初始化,时间步设定为 T-1
- real_unet 作为固定权重的教师网络,表征真实数据分布
- fake_unet 用于对生成器的数据分布进行建模
阶段 1:噪声到图像的生成
生成器接收随机噪声图作为输入,通过单步去噪操作生成图像 x,此时生成的图像 x 符合生成器的概率密度分布 p_fake
阶段 2:高斯噪声注入
对生成图像 x 施加高斯噪声,获得噪声图像 xt,在 0.2T 到 0.98T 范围内均匀采样时间步 t(避开极端噪声状态),噪声注入操作促进 p_fake 与 p_real 分布的重叠,为后续分布比较创造条件
阶段 3:双重网络处理
- real_unet 生成 pred_real_image,作为清晰图像的参考近似
- fake_unet 生成 pred_fake_image,反映当前时间步的生成器分布特征
通过对比 pred_real_image 和 pred_fake_image 量化真实分布与生成分布的差异
阶段 4:损失计算
计算 x 与 x — grad 之间的均方误差(MSE)作为损失度量。其中 x — grad 表示经过梯度校正的输出,用于减小与真实数据分布的偏差。
阶段 5:假分布更新机制
fake_unet 通过 x 和 pred_fake_image 之间的扩散损失进行参数更新。这一过程使 fake unet 能够追踪生成器分布的动态变化。与传统 unet 使用 xt-1_pred 和 xt-1_gt 计算损失不同,这里采用 xt-1_pred 和 x 之间的损失,使 fake UNet 能够将生成器输出的噪声版本(xt)还原为当前生成器输出 x。
核心问题解析
问题 1: 为何 fake_unet 采用 xt-1_pred 和 x0 之间的散度作为损失度量,而非采用 xt-1_pred 和 xt-1_gt 的比较?
选择 xt-1_pred 和 x 之间的散度是基于 fake_unet 的核心功能考虑。其目标是将生成器输出的噪声版本(xt)映射回生成器的当前输出(x)。这种设计确保了 fake_unet 能够准确捕获生成器的动态分布特征,从而提供有效的梯度信息来优化生成器输出。
问题 2:fake_unet 的必要性何在?是否可以直接利用预训练的 real_unet 输出与生成器输出计算 KL 散度?
生成器的设计目标是实现单步完全去噪,而预训练的 real_unet 在相同时间步内仅能实现部分去噪。这种本质差异导致 real_unet 输出无法提供有效的 KL 散度用于生成器训练。相比之下,fake_unet 通过持续学习生成器的动态分布,能够准确approximation当前生成器输出的特征。通过比较 real_unet 和 fake_unet 的输出,可以获得用于优化生成器概率分布的有效梯度方向,从而提升单步图像合成的质量。# 分布匹配损失机制
训练过程中,通过 KL 散度定量评估生成器分布与真实分布之间的差异。
其中 Preal 代表真实数据的概率密度函数,Pfake 表示生成器 Gθ 产生的假分布概率密度函数。
对于高维数据集,直接计算概率密度在计算复杂度上存在显著挑战。例如,对于 32×32 像素的灰度图像,其维度空间为 256¹⁰²⁴,直接计算在实际应用中不可行。
因此,采用分数函数对真实分布和生成分布进行特征表征。
这种方法使得 KL 散度的计算成为可能:Sreal 引导 x 向 Preal 的模态靠近,而 −Sfake 则促使其远离真实分布。
其中 Sreal(x) 为真实数据分布的分数函数,Sfake(x) 为生成数据分布的分数函数,∇θ Gθ(z) 表示生成器输出 x 对参数的梯度。
Sreal(x)−Sfake(x) 表征了真实分数与生成分数的差异。对于生成样本 x,由于其 Sreal 接近零,需要引入扰动以支持扩散模型从 xt 进行去噪。
Sfake 和 Sreal 的定义参考自论文 "Song et al. — Score-based generative modeling through stochastic differential equations"
最终损失函数
技术原理剖析
在时间步 t−1,利用 real_unet 和 fake_unet 的输出构建梯度,引导生成器的当前输出 x 向 real_unet 在 t=0 时刻的输出收敛。随后计算生成器原始输出与梯度校正后输出的均方误差(MSE)。这一校正机制确保 x 能够逐步对齐真实数据分布。
损失函数的代码实现
该图展示了不同时间步的损失函数变化,详细说明了多步生成器对单步生成器的训练过程。注意: 图中未详细展示 weighting_factor 相关细节,并对底层分布作出了特定假设。
核心思想在于利用 xfake 和 xreal 之间的差异产生的梯度,将生成器输出引导至 real_unet 在 t=0 时刻的目标输出。随着训练进行,生成器输出逐步向真实分布靠近,同时带动 fake_unet 输出的优化。最终,校正后的图像 ∥x−grad∥ 收敛至真实分布。
总结
本文深入探讨了分布匹配蒸馏(DMD)的技术原理和实现机制,着重阐述了其在图像生成领域的应用价值。欢迎学术界同仁就相关技术细节提供建议和讨论,以促进该领域的持续发展。
作者:Om Rastogi