0


Gumbel-Softmax完全解析

写在前面

本文对大部分人来说可能仅仅起到科普的作用,因为Gumbel-Max仅在部分领域会用到,例如GAN、VAE等。笔者是在研究EMNLP上的一篇论文时,看到其中有用Gumbel-Softmax公式解决对一个概率分布进行采样无法求导的问题,故想到对Gumbel-Softmax做一个总结,由此写下本文

为什么我们需要Gumbel-Softmax ?

假设现在我们有一个离散随机变量

    Z
   
  
  
   Z
  
 
Z的分布

 
  
   
    
     
      p
     
     
      1
     
    
    
     =
    
    
     p
    
    
     (
    
    
     Z
    
    
     =
    
    
     1
    
    
     )
    
    
     =
    
    
     
      π
     
     
      1
     
    
    
    
     
      p
     
     
      2
     
    
    
     =
    
    
     p
    
    
     (
    
    
     Z
    
    
     =
    
    
     2
    
    
     )
    
    
     =
    
    
     
      π
     
     
      2
     
    
    
    
     
      p
     
     
      3
     
    
    
     =
    
    
     p
    
    
     (
    
    
     Z
    
    
     =
    
    
     3
    
    
     )
    
    
     =
    
    
     
      π
     
     
      3
     
    
    
    
     .
    
    
     .
    
    
     .
    
    
    
     
      p
     
     
      x
     
    
    
     =
    
    
     p
    
    
     (
    
    
     Z
    
    
     =
    
    
     x
    
    
     )
    
    
     =
    
    
     
      π
     
     
      x
     
    
    
   
   
     p_1 = p(Z=1)=\pi_1\\ p_2 = p(Z=2) = \pi_2\\ p_3 = p(Z=3) = \pi_3\\ ...\\ p_x = p(Z=x) = \pi_x\\ 
   
  
 p1​=p(Z=1)=π1​p2​=p(Z=2)=π2​p3​=p(Z=3)=π3​...px​=p(Z=x)=πx​

其中,

     ∑
    
    
     i
    
   
   
    
     π
    
    
     i
    
   
   
    =
   
   
    1
   
  
  
   \sum_i \pi_i=1
  
 
∑i​πi​=1。我们想根据

 
  
   
    
     p
    
    
     1
    
   
   
    ,
   
   
    
     p
    
    
     2
    
   
   
    ,
   
   
    .
   
   
    .
   
   
    .
   
   
    ,
   
   
    
     p
    
    
     x
    
   
  
  
   p_1,p_2,...,p_x
  
 
p1​,p2​,...,px​的概率采样得到一系列离散

 
  
   
    z
   
  
  
   z
  
 
z的值。但是这么做有一个问题,**我们采样出来的
 
  
   
    
     z
    
   
   
    z
   
  
 z只有值,没有生成
 
  
   
    
     z
    
   
   
    z
   
  
 z的式子**。例如我们要求

 
  
   
    Z
   
  
  
   Z
  
 
Z的期望,那么就有公式

 
  
   
    
     E
    
    
     (
    
    
     Z
    
    
     )
    
    
     =
    
    
     
      p
     
     
      1
     
    
    
     +
    
    
     2
    
    
     
      p
     
     
      2
     
    
    
     +
    
    
     ⋯
    
    
     +
    
    
     x
    
    
     
      p
     
     
      x
     
    
   
   
     \mathbb{E}(Z) = p_1 + 2p_2 + \cdots +xp_x 
   
  
 E(Z)=p1​+2p2​+⋯+xpx​


 
  
   
    Z
   
  
  
   Z
  
 
Z对

 
  
   
    
     p
    
    
     1
    
   
   
    ,
   
   
    
     p
    
    
     2
    
   
   
    ,
   
   
    .
   
   
    .
   
   
    .
   
   
    ,
   
   
    
     p
    
    
     x
    
   
  
  
   p_1,p_2,...,p_x
  
 
p1​,p2​,...,px​的导数都很清楚。但是现在我们的需求是采样一些具体的

 
  
   
    z
   
  
  
   z
  
 
z值,采样这个操作没有任何公式,因此也就无法求导。于是一个很自然的想法就产生了,我们能不能给一个**以
 
  
   
    
     
      p
     
     
      1
     
    
    
     ,
    
    
     
      p
     
     
      2
     
    
    
     ,
    
    
     .
    
    
     .
    
    
     .
    
    
     ,
    
    
     
      p
     
     
      z
     
    
   
   
    p_1,p_2,...,p_z
   
  
 p1​,p2​,...,pz​为参数的公式,让这个公式返回的结果是
 
  
   
    
     z
    
   
   
    z
   
  
 z采样的结果呢?**

Gumbel-Softmax

一般来说

     π
    
    
     i
    
   
  
  
   \pi_i
  
 
πi​是通过神经网络预测对于类别

 
  
   
    i
   
  
  
   i
  
 
i的概率,这在分类问题中非常常见,假设我们将一个样本送入模型,最后输出的概率分布为

 
  
   
    [
   
   
    0.2
   
   
    ,
   
   
    0.4
   
   
    ,
   
   
    0.1
   
   
    ,
   
   
    0.2
   
   
    ,
   
   
    0.1
   
   
    ]
   
  
  
   [0.2, 0.4,0.1,0.2,0.1]
  
 
[0.2,0.4,0.1,0.2,0.1],表明这是一个5分类问题,其中概率最大的是第2类,到这一步,我们直接通过argmax就能获得结果了,但现在我们不是预测问题,而是一个采样问题。对于模型来说,直接取出概率最大的就可以了,但对我们来说,每个类别都是有一定概率的,我们想根据这个概率来进行采样,而不是直接简单无脑的输出概率最大的值

最常见的采样

    z
   
  
  
   \mathbf{z}
  
 
z的onehot公式为

 
  
   
    
     
     
      
       
        z
       
       
        =
       
       
        onehot
       
       
        (
       
       
        max
       
       
        ⁡
       
       
        {
       
       
        i
       
       
        ∣
       
       
        
         π
        
        
         1
        
       
       
        +
       
       
        
         π
        
        
         2
        
       
       
        +
       
       
        ⋯
       
       
        +
       
       
        
         π
        
        
         
          i
         
         
          −
         
         
          1
         
        
       
       
        ≤
       
       
        u
       
       
        }
       
       
        )
       
      
     
     
     
      
       (1)
      
     
    
   
   
     \mathbf{z} = \text{onehot}(\max \{i\mid \pi_1 + \pi_2+\cdots +\pi_{i-1} \leq u\})\tag{1} 
   
  
 z=onehot(max{i∣π1​+π2​+⋯+πi−1​≤u})(1)

其中

    i
   
   
    =
   
   
    1
   
   
    ,
   
   
    2
   
   
    ,
   
   
    .
   
   
    .
   
   
    ,
   
   
    x
   
  
  
   i=1,2,..,x
  
 
i=1,2,..,x是类别的下标,随机变量

 
  
   
    u
   
  
  
   u
  
 
u服从均匀分布

 
  
   
    U
   
   
    (
   
   
    0
   
   
    ,
   
   
    1
   
   
    )
   
  
  
   U(0,1)
  
 
U(0,1)

上面这个过程实际上是很巧妙的,我们将概率分布从前往后不断加起来,当加到

      π
     
     
      i
     
    
   
   
    \pi_i
   
  
 πi​时超过了某个随机值$ 0\leq u \leq 1
 
  
   
    
     ,
    
    
     那
    
    
     么
    
    
     这
    
    
     一
    
    
     次
    
    
     随
    
    
     机
    
    
     采
    
    
     样
    
    
     过
    
    
     程
    
    
     ,
    
   
   
    ,那么这一次随机采样过程,
   
  
 ,那么这一次随机采样过程,z
 
  
   
    
     就
    
    
     被
    
    
     随
    
    
     机
    
    
     采
    
    
     样
    
    
     为
    
    
     第
    
   
   
    就被随机采样为第
   
  
 就被随机采样为第i$类,最后通过一个onehot变换

但是上述公式存在一个致命的问题:max函数是不可导的

Gumbel-Max Trick

Gumbel-Max技巧就是解决max函数不可导问题的,我们可以用argmax替换max,即

        z
       
       
        =
       
       
        onehot
       
       
        (
       
       
        
         
          argmax
         
        
        
         i
        
       
       
        {
       
       
        
         g
        
        
         i
        
       
       
        +
       
       
        log
       
       
        ⁡
       
       
        
         π
        
        
         i
        
       
       
        }
       
       
        )
       
      
     
     
     
      
       (2)
      
     
    
   
   
     \mathbf{z} = \text{onehot}(\mathop{\text{argmax}}\limits_{i} \{g_i + \log \pi_i\})\tag{2} 
   
  
 z=onehot(iargmax​{gi​+logπi​})(2)

其中,

     g
    
    
     i
    
   
   
    =
   
   
    −
   
   
    log
   
   
    ⁡
   
   
    (
   
   
    −
   
   
    log
   
   
    ⁡
   
   
    (
   
   
    
     u
    
    
     i
    
   
   
    )
   
   
    )
   
   
    ,
   
   
    
     u
    
    
     i
    
   
   
    ∼
   
   
    U
   
   
    (
   
   
    0
   
   
    ,
   
   
    1
   
   
    )
   
  
  
   g_i=-\log(-\log(u_i)), u_i \sim U(0,1)
  
 
gi​=−log(−log(ui​)),ui​∼U(0,1),这一项名为Gumbel噪声,或者叫Gumbel分布,目的是使得

 
  
   
    z
   
  
  
   \mathbf{z}
  
 
z的返回结果不固定

可以看到式

    (
   
   
    2
   
   
    )
   
  
  
   (2)
  
 
(2)的整个过程中,不可导的部分只有argmax,实际上我们可以用可导的softmax函数,在参数

 
  
   
    τ
   
  
  
   \tau
  
 
τ的控制下逼近argmax,最终

 
  
   
    
     z
    
    
     i
    
   
  
  
   z_i
  
 
zi​的公式为

 
  
   
    
     
     
      
       
        
         z
        
        
         i
        
       
       
        =
       
       
        
         
          exp
         
         
          ⁡
         
         
          (
         
         
          
           
            
             g
            
            
             i
            
           
           
            +
           
           
            log
           
           
            ⁡
           
           
            
             π
            
            
             i
            
           
          
          
           τ
          
         
         
          )
         
        
        
         
          
           ∑
          
          
           j
          
          
           x
          
         
         
          exp
         
         
          ⁡
         
         
          (
         
         
          
           
            
             g
            
            
             j
            
           
           
            +
           
           
            log
           
           
            ⁡
           
           
            
             π
            
            
             j
            
           
          
          
           τ
          
         
         
          )
         
        
       
      
     
     
     
      
       (3)
      
     
    
   
   
     z_i = \frac{\exp(\frac{g_i + \log \pi_i}{\tau})}{\sum_{j}^x\exp(\frac{g_j + \log \pi_j}{\tau})}\tag{3} 
   
  
 zi​=∑jx​exp(τgj​+logπj​​)exp(τgi​+logπi​​)​(3)

其中,

    τ
   
  
  
   \tau
  
 
τ越小

 
  
   
    (
   
   
    τ
   
   
    →
   
   
    0
   
   
    )
   
  
  
   (\tau \to 0)
  
 
(τ→0),整个softmax越光滑逼近argmax,并且

 
  
   
    z
   
   
    =
   
   
    {
   
   
    
     z
    
    
     i
    
   
   
    ∣
   
   
    i
   
   
    =
   
   
    1
   
   
    ,
   
   
    2
   
   
    ,
   
   
    .
   
   
    .
   
   
    .
   
   
    ,
   
   
    x
   
   
    }
   
  
  
   \mathbf{z} = \{z_i\mid i=1,2,...,x\}
  
 
z={zi​∣i=1,2,...,x}也越接近onehot向量;

 
  
   
    τ
   
  
  
   \tau
  
 
τ越大

 
  
   
    (
   
   
    τ
   
   
    →
   
   
    ∞
   
   
    )
   
  
  
   (\tau \to \infty)
  
 
(τ→∞),

 
  
   
    z
   
  
  
   \mathbf{z}
  
 
z向量越接近于均匀分布

总结

整个过程相当于我们把不可导的取样过程,从

    z
   
  
  
   \mathbf{z}
  
 
z本身转移到了求

 
  
   
    z
   
  
  
   \mathbf{z}
  
 
z的公式中的一项

 
  
   
    
     g
    
    
     i
    
   
  
  
   g_i
  
 
gi​中,而

 
  
   
    
     g
    
    
     i
    
   
  
  
   g_i
  
 
gi​本身不依赖

 
  
   
    
     p
    
    
     1
    
   
   
    ,
   
   
    .
   
   
    .
   
   
    ,
   
   
    
     p
    
    
     x
    
   
  
  
   p_1,..,p_x
  
 
p1​,..,px​,所以

 
  
   
    z
   
  
  
   z
  
 
z对

 
  
   
    
     p
    
    
     1
    
   
   
    ,
   
   
    .
   
   
    .
   
   
    .
   
   
    ,
   
   
    
     p
    
    
     x
    
   
  
  
   p_1,...,p_x
  
 
p1​,...,px​就可以到了,而且我们得到的

 
  
   
    z
   
  
  
   \mathbf{z}
  
 
z仍然是离散概率分布的采样。这种采样过程转嫁的技巧有一个专有名词,叫**重参数化技巧(Reparameterization Trick)**

References

  • What is Gumbel-Softmax
  • Gumbel-Softmax Trick和Gumbel分布

本文转载自: https://blog.csdn.net/qq_37236745/article/details/127890859
版权归原作者 数学家是我理想 所有, 如有侵权,请联系我们删除。

“Gumbel-Softmax完全解析”的评论:

还没有评论