0


Gumbel-Softmax的logits输入可以是模型的输出

如下是Gumbel-Softmax的pytorch代码实现:

defgumbel_softmax(logits: torch.Tensor, tau:float=1, hard:bool=False, dim:int=-1)-> torch.Tensor:# _gumbels = (-torch.empty_like(#     logits,#     memory_format=torch.legacy_contiguous_format).exponential_().log()#             )  # ~Gumbel(0,1)# more stable https://github.com/pytorch/pytorch/issues/41663# example logits: [batch_size, n_class] unnormalized log-probs
    gumbel_dist = torch.distributions.gumbel.Gumbel(
        torch.tensor(0., device=logits.device, dtype=logits.dtype),
        torch.tensor(1., device=logits.device, dtype=logits.dtype))
    gumbels = gumbel_dist.sample(logits.shape)

    gumbels =(logits + gumbels)/ tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)if hard:# Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index,1.0)
        ret = y_hard - y_soft.detach()+ y_soft
    else:# Reparametrization trick.
        ret = y_soft
    return ret

在Gumbel-Softmax的使用中,可以使用未归一化的网络输出(即未经过 Softmax 处理也未经过 Log 处理)作为logits,这是因为Gumbel-Softmax的采样过程本质上依赖于logits的相对大小,而不绝对要求logits是概率的log值(为什么这样使用?个人认为是为了简化计算、数值稳定或提供更好的梯度性质)。以下是Gumbel-Softmax的公式,logits指的是

     log 
    
   
     ⁡ 
    
   
     ( 
    
   
     π 
    
   
     ) 
    
   
  
    \log(\pi) 
   
  
log(π), 
 
  
   
   
     π 
    
   
  
    \pi 
   
  
π指的是概率, 
 
  
   
   
     g 
    
   
  
    g 
   
  
g指的是Gumbel分布:

  
   
    
     
     
       y 
      
     
       i 
      
     
    
      = 
     
     
      
      
        exp 
       
      
        ⁡ 
       
      
        ( 
       
      
        ( 
       
      
        log 
       
      
        ⁡ 
       
      
        ( 
       
       
       
         π 
        
       
         i 
        
       
      
        ) 
       
      
        + 
       
       
       
         g 
        
       
         i 
        
       
      
        ) 
       
      
        / 
       
      
        τ 
       
      
        ) 
       
      
      
       
       
         ∑ 
        
        
        
          j 
         
        
          = 
         
        
          1 
         
        
       
         k 
        
       
      
        exp 
       
      
        ⁡ 
       
      
        ( 
       
      
        ( 
       
      
        log 
       
      
        ⁡ 
       
      
        ( 
       
       
       
         π 
        
       
         j 
        
       
      
        ) 
       
      
        + 
       
       
       
         g 
        
       
         j 
        
       
      
        ) 
       
      
        / 
       
      
        τ 
       
      
        ) 
       
      
     
    
   
     y_i = \frac{\exp((\log(\pi_i) + g_i) / \tau)}{\sum_{j=1}^k \exp((\log(\pi_j) + g_j) / \tau)} 
    
   
 yi​=∑j=1k​exp((log(πj​)+gj​)/τ)exp((log(πi​)+gi​)/τ)​

也可以参考Gumbel-Softmax官方代码的使用示例。

标签: 深度学习

本文转载自: https://blog.csdn.net/m0_46294481/article/details/139125402
版权归原作者 还好我不在意 所有, 如有侵权,请联系我们删除。

“Gumbel-Softmax的logits输入可以是模型的输出”的评论:

还没有评论