0


深入浅出PyTorch中的nn.CrossEntropyLoss

👨‍💻 作者简介:非科班转码,正在不断丰富自己的技术栈
🗒️ 博客主页:https://raelum.blog.csdn.net
🎯 主要领域:NLP、RS、GNN
📢 如果这篇文章有帮助到你,可以关注❤️ + 点赞👍 + 收藏⭐ + 留言💬,这将是我创作的最大动力

在这里插入图片描述

目录

一、前言

nn.CrossEntropyLoss

常用作多分类问题的损失函数(对交叉熵还不了解的读者可以看我的这篇文章),本文将围绕PyTorch的官方文档对重要知识点进行逐一讲解(不会全部讲解)。

import torch
import torch.nn as nn

二、理论基础

对于

    C
    
   
    (
   
   
    C
   
   
    >
   
   
    2
   
   
    )
   
  
  
   C\,(C>2)
  
 
C(C>2) 分类问题,先不考虑 batch 的情形,设神经网络的输出(**还未经过 Softmax**)为 

 
  
   
    {
   
   
    
     x
    
    
     c
    
   
   
    
     }
    
    
     
      c
     
     
      =
     
     
      1
     
    
    
     C
    
   
  
  
   \{x_c\}_{c=1}^C
  
 
{xc​}c=1C​,经过 Softmax 后得到

 
  
   
    
     
      q
     
     
      i
     
    
    
     =
    
    
     
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        i
       
      
      
       )
      
     
     
      
       
        ∑
       
       
        
         c
        
        
         =
        
        
         1
        
       
       
        C
       
      
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        c
       
      
      
       )
      
     
    
   
   
     q_i=\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} 
   
  
 qi​=∑c=1C​exp(xc​)exp(xi​)​

从而该样本的交叉熵损失为

     H
    
    
     (
    
    
     p
    
    
     ,
    
    
     q
    
    
     )
    
    
     =
    
    
     −
    
    
     
      ∑
     
     
      
       i
      
      
       =
      
      
       1
      
     
     
      C
     
    
    
     
      p
     
     
      i
     
    
    
     log
    
    
     ⁡
    
    
     
      q
     
     
      i
     
    
    
     =
    
    
     −
    
    
     
      ∑
     
     
      
       i
      
      
       =
      
      
       1
      
     
     
      C
     
    
    
     
      p
     
     
      i
     
    
    
     log
    
    
     ⁡
    
    
     
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        i
       
      
      
       )
      
     
     
      
       
        ∑
       
       
        
         c
        
        
         =
        
        
         1
        
       
       
        C
       
      
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        c
       
      
      
       )
      
     
    
   
   
     H(p,q)=-\sum_{i=1}^C p_i\log q_i=-\sum_{i=1}^C p_i\log\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} 
   
  
 H(p,q)=−i=1∑C​pi​logqi​=−i=1∑C​pi​log∑c=1C​exp(xc​)exp(xi​)​

其中

    (
   
   
    
     p
    
    
     1
    
   
   
    ,
   
   
    
     p
    
    
     2
    
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    
     p
    
    
     C
    
   
   
    )
   
  
  
   (p_1,p_2,\cdots,p_C)
  
 
(p1​,p2​,⋯,pC​) 是 One-Hot 向量。

不妨令

     p
    
    
     y
    
   
   
    =
   
   
    1
    
   
    (
   
   
    y
   
   
    ∈
   
   
    {
   
   
    1
   
   
    ,
   
   
    2
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    C
   
   
    }
   
   
    )
   
  
  
   p_y=1\,(y\in\{1,2,\cdots,C\})
  
 
py​=1(y∈{1,2,⋯,C}),其余为 

 
  
   
    0
   
  
  
   0
  
 
0,因此上式变为

 
  
   
    
     H
    
    
     (
    
    
     p
    
    
     ,
    
    
     q
    
    
     )
    
    
     =
    
    
     −
    
    
     log
    
    
     ⁡
    
    
     
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        y
       
      
      
       )
      
     
     
      
       
        ∑
       
       
        
         c
        
        
         =
        
        
         1
        
       
       
        C
       
      
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        c
       
      
      
       )
      
     
    
   
   
     H(p,q)=-\log\frac{\exp(x_y)}{\sum_{c=1}^C\exp(x_c)} 
   
  
 H(p,q)=−log∑c=1C​exp(xc​)exp(xy​)​

现在考虑有 batch 的情形,不妨设 batch size 为

    N
   
  
  
   N
  
 
N,神经网络的输出为 

 
  
   
    {
   
   
    
     x
    
    
     
      n
     
     
      c
     
    
   
   
    
     }
    
    
     
      n
     
     
      c
     
    
   
   
    ,
     
   
    n
   
   
    =
   
   
    1
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    N
   
   
    ,
     
   
    c
   
   
    =
   
   
    1
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    C
   
  
  
   \{x_{nc}\}_{nc},\;n=1,\cdots,N,\;c=1,\cdots,C
  
 
{xnc​}nc​,n=1,⋯,N,c=1,⋯,C,第 

 
  
   
    n
   
  
  
   n
  
 
n 个样本的真实类别记为 

 
  
   
    
     y
    
    
     n
    
    
   
    (
   
   
    
     y
    
    
     n
    
   
   
    ∈
   
   
    {
   
   
    1
   
   
    ,
   
   
    2
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    C
   
   
    }
   
   
    )
   
  
  
   y_n\,(y_n\in\{1,2,\cdots,C\})
  
 
yn​(yn​∈{1,2,⋯,C}),第 

 
  
   
    n
   
  
  
   n
  
 
n 个样本的交叉熵损失记为 

 
  
   
    
     l
    
    
     n
    
   
  
  
   l_n
  
 
ln​,则仿照上式就有

 
  
   
    
     
      l
     
     
      n
     
    
    
     =
    
    
     −
    
    
     log
    
    
     ⁡
    
    
     
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        
         n
        
        
         ,
        
        
         
          y
         
         
          n
         
        
       
      
      
      
       )
      
     
     
      
       
        ∑
       
       
        
         c
        
        
         =
        
        
         1
        
       
       
        C
       
      
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        
         n
        
        
         c
        
       
      
      
       )
      
     
    
   
   
     l_n=-\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} 
   
  
 ln​=−log∑c=1C​exp(xnc​)exp(xn,yn​​)​

接下来我们讨论一些特殊情形。当数据不平衡时(某一类的样本数特别多,另一类的样本数特别少),我们需要为每一类的损失安排一个权重用来平衡。权重为

    w
   
   
    =
   
   
    (
   
   
    
     w
    
    
     1
    
   
   
    ,
   
   
    
     w
    
    
     2
    
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    
     w
    
    
     C
    
   
   
    )
   
  
  
   \boldsymbol{w}=(w_1,w_2,\cdots,w_C)
  
 
w=(w1​,w2​,⋯,wC​)。

📌 模型容易在样本数最多的一个(或几个)类上过拟合,因此对于那些样本数较少的类,我们需要设置更高的权重,这样模型在预测这些类的标签时一旦出错,就会受到更多的惩罚

安排了权重后,相应的损失为

      l
     
     
      n
     
    
    
     =
    
    
     −
    
    
     
      w
     
     
      
       y
      
      
       n
      
     
    
    
     log
    
    
     ⁡
    
    
     
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        
         n
        
        
         ,
        
        
         
          y
         
         
          n
         
        
       
      
      
      
       )
      
     
     
      
       
        ∑
       
       
        
         c
        
        
         =
        
        
         1
        
       
       
        C
       
      
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        
         n
        
        
         c
        
       
      
      
       )
      
     
    
   
   
     l_n=-w_{y_n}\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} 
   
  
 ln​=−wyn​​log∑c=1C​exp(xnc​)exp(xn,yn​​)​

计算完

     l
    
    
     1
    
   
   
    ,
   
   
    
     l
    
    
     2
    
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    
     l
    
    
     N
    
   
  
  
   l_1,l_2,\cdots,l_N
  
 
l1​,l2​,⋯,lN​ 后,我们既可以一次性将它们**全部返回**(对应 
reduction=none

),也可以返回它们的均值(对应

reduction=mean

),还可以返回它们的(对应

reduction=sum

):

     ℓ
    
    
     =
    
    
     
      {
     
     
      
       
        
         
          
           (
          
          
           
            l
           
           
            1
           
          
          
           ,
          
          
           ⋯
           
          
           ,
          
          
           
            l
           
           
            N
           
          
          
           )
          
          
           ,
          
         
        
       
       
        
         
          reduction=none
         
        
       
      
      
       
        
         
          
           
            ∑
           
           
            
             n
            
            
             =
            
            
             1
            
           
           
            N
           
          
          
           
            l
           
           
            n
           
          
          
           /
          
          
           
            ∑
           
           
            
             n
            
            
             =
            
            
             1
            
           
           
            N
           
          
          
           
            w
           
           
            
             y
            
            
             n
            
           
          
          
           ,
          
         
        
       
       
        
         
          reduction=mean
         
        
       
      
      
       
        
         
          
           
            ∑
           
           
            
             n
            
            
             =
            
            
             1
            
           
           
            N
           
          
          
           
            l
           
           
            n
           
          
          
           ,
          
         
        
       
       
        
         
          reduction=sum
         
        
       
      
     
    
   
   
     \ell=\begin{cases} (l_1,\cdots,l_N),&\text{reduction=none} \\ \sum_{n=1}^N l_n/\sum_{n=1}^N w_{y_n},&\text{reduction=mean} \\ \sum_{n=1}^N l_n,&\text{reduction=sum} \\ \end{cases} 
   
  
 ℓ=⎩⎪⎨⎪⎧​(l1​,⋯,lN​),∑n=1N​ln​/∑n=1N​wyn​​,∑n=1N​ln​,​reduction=nonereduction=meanreduction=sum​

在 NLP 任务中,我们往往将填充词元添加到每个序列的末尾,这样一来不同长度的序列可以进行批量加载。训练过程中,我们不希望网络预测出的填充词元被算入损失函数中。不妨设填充词元在词表中的索引为

    i
   
  
  
   i
  
 
i,则此时应对 

 
  
   
    
     l
    
    
     n
    
   
  
  
   l_n
  
 
ln​ 作如下修正:

 
  
   
    
     
      l
     
     
      n
     
    
    
     =
    
    
     −
    
    
     
      w
     
     
      
       y
      
      
       n
      
     
    
    
     ⋅
    
    
     I
    
    
     (
    
    
     
      y
     
     
      n
     
    
    
     ≠
    
    
     i
    
    
     )
    
    
     ⋅
    
    
     log
    
    
     ⁡
    
    
     
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        
         n
        
        
         ,
        
        
         
          y
         
         
          n
         
        
       
      
      
      
       )
      
     
     
      
       
        ∑
       
       
        
         c
        
        
         =
        
        
         1
        
       
       
        C
       
      
      
       exp
      
      
       ⁡
      
      
       (
      
      
       
        x
       
       
        
         n
        
        
         c
        
       
      
      
       )
      
     
    
    
     ,
    
    
    
     where
      
    
     I
    
    
     (
    
    
     x
    
    
     )
    
    
     =
    
    
     
      {
     
     
      
       
        
         
          
           1
          
          
           ,
          
         
        
       
       
        
         
          
           x
            
          
           is True
          
         
        
       
      
      
       
        
         
          
           0
          
          
           ,
          
         
        
       
       
        
         
          
           x
            
          
           is False
          
         
        
       
      
     
    
   
   
     l_n=-w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})},\qquad \text{where}\; \mathbb{I}(x)= \begin{cases} 1,&x\; \text{is True} \\ 0,&x\; \text{is False} \end{cases} 
   
  
 ln​=−wyn​​⋅I(yn​​=i)⋅log∑c=1C​exp(xnc​)exp(xn,yn​​)​,whereI(x)={1,0,​xis Truexis False​

另外,该场景下的

reduction=mean

对应的损失变为

     ℓ
    
    
     =
    
    
     
      ∑
     
     
      
       n
      
      
       =
      
      
       1
      
     
     
      N
     
    
    
     
      
       l
      
      
       n
      
     
     
      
       
        ∑
       
       
        
         n
        
        
         =
        
        
         1
        
       
       
        N
       
      
      
       
        w
       
       
        
         y
        
        
         n
        
       
      
      
       ⋅
      
      
       I
      
      
       (
      
      
       
        y
       
       
        n
       
      
      
       ≠
      
      
       i
      
      
       )
      
     
    
   
   
     \ell=\sum_{n=1}^N\frac{l_n}{\sum_{n=1}^Nw_{y_n}\cdot \mathbb{I}(y_n\neq i)} 
   
  
 ℓ=n=1∑N​∑n=1N​wyn​​⋅I(yn​​=i)ln​​

📌 **需要注意的是,在PyTorch中

       y
      
      
       n
      
     
     
      ∈
     
     
      {
     
     
      0
     
     
      ,
     
     
      1
     
     
      ,
     
     
      ⋯
      
     
      ,
     
     
      C
     
     
      −
     
     
      1
     
     
      }
     
    
    
     y_n\in\{0,1,\cdots,C-1\}
    
   
  yn​∈{0,1,⋯,C−1},这里我们之所以用 
  
   
    
     
      {
     
     
      1
     
     
      ,
     
     
      2
     
     
      ,
     
     
      ⋯
      
     
      ,
     
     
      C
     
     
      }
     
    
    
     \{1,2,\cdots,C\}
    
   
  {1,2,⋯,C} 是为了更自然地衔接上下文**

三、主要参数

nn.CrossEntropyLoss

的主要参数如下:

nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)

⚠️ **

size_average

reduce

参数已经弃用,取而代之的是

reduction

参数,所以这里不再讲解**


有了前面的铺垫,我们就可以很容易理解这些参数了:

  • weight:长度为 C C C 的张量,一般在数据不平衡时才会使用;
  • ignore_index:需要忽略的类别的索引,默认为 − 100 -100 −100,即不忽略;
  • reduction:决定以何种形式返回损失。为 none 时返回 N N N 个样本的损失,为 mean 时返回 N N N 个样本的损失均值,为 sum 时返回 N N N 个样本的损失的和。默认为 mean
  • label_smoothing:决定是否开启标签平滑(不了解标签平滑的读者可参考这篇文章),数值在 [ 0 , 1 ] [0,1] [0,1] 内。默认为 0 0 0,即不开启。

3.1 输入与输出

输入分为

input

target

input

通常为

    (
   
   
    N
   
   
    ,
   
   
    C
   
   
    )
   
  
  
   (N,C)
  
 
(N,C) 的形状(即 
batch_size × num_classes

),

target

通常为

    (
   
   
    N
   
   
    ,
   
   
    )
   
  
  
   (N,)
  
 
(N,) 的形状,其中的每个分量均位于 

 
  
   
    [
   
   
    0
   
   
    ,
   
   
    C
   
   
    −
   
   
    1
   
   
    ]
   
   
    ∩
   
   
    Z
   
  
  
   [0,C-1] \cap \mathbb{Z}
  
 
[0,C−1]∩Z 中,代表样本属于的类别。

📌 **

input

target

还可以是其他类型的输入,但本文只讨论这种使用最为广泛的输入**
📌 **

input

是神经网络的原始输出(未经过 Softmax),

nn.CrossEntropyLoss

会自动对其应用 Softmax**

torch.manual_seed(0)
batch_size =3
num_classes =5
criterion_1 = nn.CrossEntropyLoss(reduction='none')
criterion_2 = nn.CrossEntropyLoss()
criterion_3 = nn.CrossEntropyLoss(reduction='sum')

inputs = torch.randn(batch_size, num_classes)# 避免与input关键字冲突(当然这无所谓)
target = torch.randint(num_classes, size=(batch_size,))print(criterion_1(inputs, target))# 输出3个样本的loss# tensor([1.4639, 3.0493, 2.3056])print(criterion_2(inputs, target))# 输出3个样本的loss的均值# tensor(2.2729)print(criterion_3(inputs, target))# 输出3个样本的loss的和# tensor(6.8188)print(sum(criterion_1(inputs, target))== criterion_3(inputs, target))# tensor(True)print(sum(criterion_1(inputs, target))/ batch_size == criterion_2(inputs, target))# tensor(True)

四、从零开始实现 nn.CrossEntropyLoss

为了加深理解,接下来我们从零开始实现

nn.CrossEntropyLoss

(当然会和官方不同,为了追求可读性会采用傻瓜式实现)。

首先确定框架(为简便起见这里不考虑

label_smoothing

):

classCrossEntropyLoss(nn.Module):def__init__(self, weight=None, ignore_index=-100, reduction='mean'):super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction
        
    defforward(self, inputs, target):pass

为方便计算,我们对第二章节的损失计算公式进行改写

      l
     
     
      n
     
    
    
     =
    
    
     
      w
     
     
      
       y
      
      
       n
      
     
    
    
     ⋅
    
    
     I
    
    
     (
    
    
     
      y
     
     
      n
     
    
    
     ≠
    
    
     i
    
    
     )
    
    
     ⋅
    
    
     [
    
    
     −
    
    
     
      x
     
     
      
       n
      
      
       ,
      
      
       
        y
       
       
        n
       
      
     
    
    
     +
    
    
     log
    
    
     ⁡
    
    
     
      ∑
     
     
      
       c
      
      
       =
      
      
       1
      
     
     
      C
     
    
    
     exp
    
    
     ⁡
    
    
     (
    
    
     
      x
     
     
      
       n
      
      
       c
      
     
    
    
     )
    
    
     ]
    
   
   
     l_n=w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot[-x_{n,y_n}+\log\sum_{c=1}^C\exp(x_{nc})] 
   
  
 ln​=wyn​​⋅I(yn​​=i)⋅[−xn,yn​​+logc=1∑C​exp(xnc​)]

采用更符合 Python 的表述方式来改写上式

      l
     
     
      n
     
    
    
     =
    
    
     w
    
    
     [
    
    
     
      y
     
     
      n
     
    
    
     ]
    
    
     ⋅
    
    
     I
    
    
     (
    
    
     
      y
     
     
      n
     
    
    
     ≠
    
    
     i
    
    
     )
    
    
     ⋅
    
    
     [
    
    
     −
    
    
     
      x
     
     
      n
     
    
    
     [
    
    
     
      y
     
     
      n
     
    
    
     ]
    
    
     +
    
    
     log
    
    
     ⁡
    
    
     
      ∑
     
     
      
       c
      
      
       =
      
      
       1
      
     
     
      C
     
    
    
     exp
    
    
     ⁡
    
    
     (
    
    
     
      x
     
     
      n
     
    
    
     [
    
    
     c
    
    
     ]
    
    
     )
    
    
     ]
    
   
   
     l_n=\boldsymbol{w}[y_n]\cdot \mathbb{I}(y_n\neq i)\cdot[-\boldsymbol{x_n}[y_n]+\log\sum_{c=1}^C\exp(\boldsymbol{x_n}[c])] 
   
  
 ln​=w[yn​]⋅I(yn​​=i)⋅[−xn​[yn​]+logc=1∑C​exp(xn​[c])]

其中

    w
   
   
    =
   
   
    (
   
   
    
     w
    
    
     1
    
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    
     w
    
    
     C
    
   
   
    )
   
   
    ,
     
   
    
     x
    
    
     n
    
   
   
    =
   
   
    (
   
   
    
     x
    
    
     
      n
     
     
      1
     
    
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    
     x
    
    
     
      n
     
     
      C
     
    
   
   
    )
   
  
  
   \boldsymbol{w}=(w_1,\cdots,w_C),\;\boldsymbol{x_n}=(x_{n1},\cdots,x_{nC})
  
 
w=(w1​,⋯,wC​),xn​=(xn1​,⋯,xnC​)。再令 

 
  
   
    X
   
   
    =
   
   
    (
   
   
    
     x
    
    
     1
    
   
   
    ;
   
   
    ⋯
    
   
    ;
   
   
    
     x
    
    
     N
    
   
   
    )
   
   
    ,
     
   
    y
   
   
    =
   
   
    (
   
   
    
     y
    
    
     1
    
   
   
    ,
   
   
    ⋯
    
   
    ,
   
   
    
     y
    
    
     C
    
   
   
    )
   
  
  
   {\bf X}=(\boldsymbol{x_1};\cdots;\boldsymbol{x_N}),\;\boldsymbol{y}=(y_1,\cdots,y_C)
  
 
X=(x1​;⋯;xN​),y=(y1​,⋯,yC​),则显然 

 
  
   
    X
   
  
  
   {\bf X}
  
 
X 就是我们的 
input

    y
   
  
  
   \boldsymbol{y}
  
 
y 就是 
target

,于是我们可以进行批量计算

     (
    
    
     
      l
     
     
      1
     
    
    
     ,
    
    
     ⋯
     
    
     ,
    
    
     
      l
     
     
      N
     
    
    
     )
    
    
     =
    
    
     w
    
    
     [
    
    
     y
    
    
     ]
    
    
     ∗
    
    
     I
    
    
     (
    
    
     y
    
    
     ≠
    
    
     i
    
    
     )
    
    
     ∗
    
    
     (
    
    
     −
    
    
     X
    
    
     [
    
    
     range
    
    
     (
    
    
     len
    
    
     (
    
    
     y
    
    
     )
    
    
     )
    
    
     ,
     
    
     y
    
    
     ]
    
    
     +
    
    
     log
    
    
     ⁡
    
    
     (
    
    
     sum
    
    
     (
    
    
     exp
    
    
     ⁡
    
    
     (
    
    
     X
    
    
     )
    
    
     ,
     
    
     dim
    
    
     =
    
    
     1
    
    
     )
    
    
     )
    
    
     )
    
   
   
     (l_1,\cdots,l_N)=\boldsymbol{w}[\boldsymbol{y}] *\mathbb{I}(\boldsymbol{y}\neq i)* (-{\bf X}[\text{range}(\text{len}(\boldsymbol{y})),\,\boldsymbol{y}]+\log(\text{sum}(\exp({\bf X}),\,\text{dim}=1))) 
   
  
 (l1​,⋯,lN​)=w[y]∗I(y​=i)∗(−X[range(len(y)),y]+log(sum(exp(X),dim=1)))

其中

    ∗
   
  
  
   *
  
 
∗ 代表按元素相乘。上式采用了广播机制。
classCrossEntropyLoss(nn.Module):def__init__(self, weight=None, ignore_index=-100, reduction='mean'):super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    defforward(self, inputs, target):if self.weight isnotNone:
            n_samples_weight = self.weight[target]# 每个样本的权重else:
            n_samples_weight = torch.ones_like(target).float()# 不提供权重则默认全为1
        indicator =(target != self.ignore_index).long().float()# long()方法可以将布尔型张量转化成0-1张量
        raw_loss =-inputs[torch.arange(len(target)), target]+ torch.log(torch.sum(torch.exp(inputs), dim=1))
        result = n_samples_weight * indicator * raw_loss
        if self.reduction =='mean':return torch.sum(result)/ n_samples_weight.dot(indicator)elif self.reduction =='sum':return torch.sum(result)else:return result

输出结果与 PyTorch 官方的

nn.CrossEntropyLoss

的完全相同,这里不再展示,读者可自行验证。


本文转载自: https://blog.csdn.net/raelum/article/details/125588956
版权归原作者 aelum 所有, 如有侵权,请联系我们删除。

“深入浅出PyTorch中的nn.CrossEntropyLoss”的评论:

还没有评论