0


PyTorch中计算KL散度详解

PyTorch计算KL散度详解

最近在进行方法设计时,需要度量分布之间的差异,由于样本间分布具有相似性,首先想到了便于实现的KL-Divergence,使用PyTorch中的内置方法时,踩了不少坑,在这里详细记录一下。

简介

首先简单介绍一下KL散度(具体的可以在各种技术博客看到讲解,我这里不做重点讨论)。
从名称可以看出来,它并不是严格意义上的距离(所以才叫做散度~),原因是它并不满足距离的对称性,为了弥补这种缺陷,出现了JS散度(这就是另一个故事了…)
我们先来看一下KL散度的形式:

  1. D
  2. K
  3. L
  4. (
  5. P
  6. Q
  7. )
  8. =
  9. i
  10. =
  11. 1
  12. N
  13. p
  14. i
  15. log
  16. p
  17. i
  18. q
  19. i
  20. =
  21. i
  22. =
  23. 1
  24. N
  25. p
  26. i
  27. (
  28. log
  29. p
  30. i
  31. log
  32. q
  33. i
  34. )
  35. DKL(P||Q) = \sum_{i=1}^{N} {p_i\log{\frac{p_i}{q_i}}} = \sum_{i=1}^{N} { p_i*(\log{p_i}-\log{q_i})}
  36. DKL(P∣∣Q)=i=1Npilogqipi​​=i=1Npi​∗(logpi​−logqi​)

手动代码实现

可以看到,KL散度形式上还是比较直观的,我们先手撸一个试试:
这里我们随机设定两个随机变量P和Q

  1. import torch
  2. P = torch.tensor([0.4,0.6])
  3. Q = torch.tensor([0.3,0.7])

快速算一下答案:

  1. D
  2. K
  3. L
  4. (
  5. P
  6. Q
  7. )
  8. =
  9. 0.4
  10. (
  11. log
  12. 0.4
  13. log
  14. 0.3
  15. )
  16. +
  17. 0.6
  18. (
  19. log
  20. 0.6
  21. log
  22. 0.7
  23. )
  24. 0.0226
  25. \begin{aligned} DKL(P||Q) &= 0.4* (\log{0.4} - \log{0.3}) + 0.6 * (\log{0.6} - \log{0.7}) \\ & \approx 0.0226 \end{aligned}
  26. DKL(P∣∣Q)​=0.4∗(log0.4log0.3)+0.6∗(log0.6log0.7)≈0.0226

数值计算实现版:

  1. defDKL(_p, _q):"""calculate the KL divergence between _p and _q
  2. """return torch.sum(_p *(_p.log()- _q.log()), dim=-1)
  3. divergence = DKL(P, Q)print(divergence)# tensor(0.0226)

上面的代码中,之所以求和时

  1. dim=-1

是因为我在使用的过程中,考虑到有时是对batch中feature进行计算,所以这里只对特征维度进行求和。
接下来,就到了今天介绍的主角~

torch代码实现

torch中提供有两种不同的api用于计算KL散度,分别是

  1. torch.nn.functional.kl_div()

  1. torch.nn.KLDivLoss()

,两者计算效果类似,区别无非是直接计算和作为损失函数类。

先介绍一下

  1. torch.nn.functional.kl_div()

注意,该方法的

  1. input

  1. target

  1. K
  2. L
  3. (
  4. P
  5. Q
  6. )
  7. KL(P||Q)
  8. KL(P∣∣Q)中
  9. P
  10. P
  11. P
  12. Q
  13. Q
  14. Q的位置正好相反,从参数名称就可以看出来(
  1. target

为目标分布

  1. P
  2. P
  3. P
  1. input

为待度量分布

  1. Q
  2. Q
  3. Q)。为了防止指代混乱,我后面统一用
  4. P
  5. P
  6. P
  7. Q
  8. Q
  9. Q指代
  1. target

  1. input


在这里插入图片描述
这里重点关注几个对计算结果有影响的参数:

  1. reduction

:该参数是结果应该以什么规约形式进行呈现,

  1. sum

即为我们定义式中的效果,

  1. batchmean

:按照batch大小求平均,

  1. mean

:按照元素个数进行求平均

再看看

  1. log_target

的效果:

  1. ifnot log_target:# default
  2. loss_pointwise = target *(target.log()-input)else:
  3. loss_pointwise = target.exp()*(target -input)

也就是说,如果

  1. log_target=False

,此时计算方式为

  1. r
  2. e
  3. s
  4. =
  5. P
  6. (
  7. log
  8. P
  9. Q
  10. )
  11. res = P * ( \log{P}-Q)
  12. res=P∗(logPQ)

这和我们熟悉的定义式的计算方式是不同的,如果想要和定义式的效果一致,需要对

  1. input

取对数操作(在官方文档中也有提及,建议将

  1. input

映射到对数空间,防止数值下溢):

  1. import torch.nn.Functional as F
  2. print(F.kl_div(Q.log(), P, reduction='sum'))#tensor(0.0226)

而当

  1. log_target=True

时,此时的计算方式变为

  1. r
  2. e
  3. s
  4. =
  5. e
  6. P
  7. (
  8. P
  9. Q
  10. )
  11. res=e^{P}*(P-Q)
  12. res=eP∗(PQ)

也就是说,此时我们对

  1. P
  2. P
  3. P取对数操作即可得到定义式的效果:
  1. print(F.kl_div(Q.log(), P.log(),
  2. log_target=True, reduction='sum'))#tensor(0.0226)

这样设计的目的也是为了防止数值下溢。

  1. torch.nn.KLDivLoss()

的参数列表与

  1. torch.nn.functional.kl_div()

类似,这里就不过多赘述。

总结

总的来说,当需要计算KL散度时,默认情况下需要对

  1. input

取对数,并设置

  1. reduction='sum'

方能得到与定义式相同的结果:

  1. divergence = F.kl_div(Q.log(), P, reduction='sum')

由于我们度量的是两个分布的差异,因此通常需要对输入进行softmax归一化(如果已经归一化则无需此操作):

  1. divergence = F.kl_div(Q.softmax(-1).log(), P.softmax(-1), reduction='sum')

本文转载自: https://blog.csdn.net/qq_45802280/article/details/127990600
版权归原作者 __init__: 所有, 如有侵权,请联系我们删除。

“PyTorch中计算KL散度详解”的评论:

还没有评论