0


清华大学出品:罚梯度范数提高深度学习模型泛化性

1 引言

神经网络结构简单,训练样本量不足,则会导致训练出来的模型分类精度不高;神经网络结构复杂,训练样本量过大,则又会导致模型过拟合,所以如何训练神经网络提高模型的泛化性是人工智能领域一个非常核心的问题。最近读到了一篇与该问题相关的文章,论文中作者在训练过程中通过在损失函数中增加正则化项梯度范数的约束从而来提高深度学习模型的泛化性。作者从原理和实验两方面分别对论文中的方法进行了详细地阐述和验证。

     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz连续是对深度学习进行理论分析中非常重要且常见的数学工具,该论文就是以神经网络损失函数
 
  
   
    
     是
    
    
     
      L
     
     
      i
     
     
      p
     
     
      s
     
     
      c
     
     
      h
     
     
      i
     
     
      t
     
     
      z
     
    
   
   
    是\mathrm{Lipschitz}
   
  
 是Lipschitz连续为出发点进行数学推导。为了方便读者能够更流畅地欣赏论文作者漂亮的数学证明思路和过程,本文对于论文中没有展开数学证明细节进行了补充。


论文链接:https://arxiv.org/abs/2202.03599

2

     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     z
    
   
   
    \mathrm{Lipschiz}
   
  
 Lipschiz连续

给定一个训练数据集

     S
    
    
     =
    
    
     {
    
    
     (
    
    
     
      x
     
     
      i
     
    
    
     ,
    
    
     
      y
     
     
      i
     
    
    
     )
    
    
     
      }
     
     
      
       i
      
      
       =
      
      
       0
      
     
     
      n
     
    
   
   
    \mathcal{S}=\{(x_i,y_i)\}_{i=0}^n
   
  
 S={(xi​,yi​)}i=0n​服从分布
 
  
   
    
     D
    
   
   
    \mathcal{D}
   
  
 D,一个带有参数
 
  
   
    
     θ
    
    
     ∈
    
    
     Θ
    
   
   
    \theta \in \Theta
   
  
 θ∈Θ的神经网络
 
  
   
    
     f
    
    
     (
    
    
     ⋅
    
    
     ;
    
    
     θ
    
    
     )
    
   
   
    f(\cdot;\theta)
   
  
 f(⋅;θ),损失函数为
  
   
    
     
      
       L
      
      
       S
      
     
     
      =
     
     
      
       1
      
      
       N
      
     
     
      
       ∑
      
      
       
        i
       
       
        =
       
       
        1
       
      
      
       N
      
     
     
      l
     
     
      (
     
     
      
       
        
         y
        
        
         i
        
       
       
        ,
       
       
        
         y
        
        
         i
        
       
       
        ,
       
       
        θ
       
      
      
       ^
      
     
     
      )
     
    
    
     L_{\mathcal{S}}=\frac{1}{N}\sum\limits_{i=1}^N l(\hat{y_i,y_i ,\theta})
    
   
  LS​=N1​i=1∑N​l(yi​,yi​,θ^​)当需要对损失函数中的梯度范数进行约束时,则有如下损失函数
  
   
    
     
      L
     
     
      (
     
     
      θ
     
     
      )
     
     
      =
     
     
      
       L
      
      
       S
      
     
     
      +
     
     
      λ
     
     
      ⋅
     
     
      ∥
     
     
      
       ∇
      
      
       θ
      
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      )
     
     
      
       ∥
      
      
       p
      
     
    
    
     L(\theta)=L_{\mathcal{S}}+\lambda \cdot \|\nabla_\theta L_{\mathcal{S}}(\theta)\|_p
    
   
  L(θ)=LS​+λ⋅∥∇θ​LS​(θ)∥p​其中
 
  
   
    
     ∥
    
    
     ⋅
    
    
     
      ∥
     
     
      p
     
    
   
   
    \|\cdot \|_p
   
  
 ∥⋅∥p​表示
 
  
   
    
     p
    
   
   
    p
   
  
 p范数,
 
  
   
    
     λ
    
    
     ∈
    
    
     
      R
     
     
      +
     
    
   
   
    \lambda\in \mathbb{R}^{+}
   
  
 λ∈R+为梯度惩罚系数。一般情况下,损失函数引入梯度的正则化项会使得其在优化过程中在局部有更小的
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz常数,
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz常数越小,就意味着损失函数就越平滑,平损失函数平滑区域易于损失函数优化权重参数。进而会使得训练出来的深度学习模型有更好的泛化性。

 深度学习中一个非常重要而且常见的概念就是

     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz连续。给定一个空间
 
  
   
    
     Ω
    
    
     ⊂
    
    
     
      R
     
     
      n
     
    
   
   
    \Omega \subset \mathbb{R}^n
   
  
 Ω⊂Rn,对于函数
 
  
   
    
     h
    
    
     :
    
    
     Ω
    
    
     →
    
    
     
      R
     
     
      m
     
    
   
   
    h:\Omega \rightarrow \mathbb{R}^m
   
  
 h:Ω→Rm,如果存在一个常数
 
  
   
    
     K
    
   
   
    K
   
  
 K,对于
 
  
   
    
     ∀
    
    
     
      θ
     
     
      1
     
    
    
     ,
    
    
     
      θ
     
     
      2
     
    
    
     ∈
    
    
     Ω
    
   
   
    \forall \theta_1,\theta_2 \in \Omega
   
  
 ∀θ1​,θ2​∈Ω满足以下条件则称
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz连续
  
   
    
     
      ∥
     
     
      h
     
     
      (
     
     
      
       θ
      
      
       1
      
     
     
      )
     
     
      −
     
     
      h
     
     
      (
     
     
      
       θ
      
      
       2
      
     
     
      )
     
     
      
       ∥
      
      
       2
      
     
     
      ≤
     
     
      K
     
     
      ⋅
     
     
      ∥
     
     
      
       θ
      
      
       1
      
     
     
      −
     
     
      
       θ
      
      
       2
      
     
     
      
       ∥
      
      
       2
      
     
    
    
     \|h(\theta_1)-h(\theta_2)\|_2 \le K \cdot \|\theta_1 - \theta_2\|_2
    
   
  ∥h(θ1​)−h(θ2​)∥2​≤K⋅∥θ1​−θ2​∥2​其中
 
  
   
    
     K
    
   
   
    K
   
  
 K表示的是
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz常数。如果对于参数空间
 
  
   
    
     Θ
    
    
     ⊂
    
    
     Ω
    
   
   
    \Theta \subset \Omega
   
  
 Θ⊂Ω,如果
 
  
   
    
     Θ
    
   
   
    \Theta
   
  
 Θ有一个邻域
 
  
   
    
     A
    
   
   
    \mathcal{A}
   
  
 A,且
 
  
   
    
     h
    
    
     
      ∣
     
     
      A
     
    
   
   
    h|_{\mathcal{A}}
   
  
 h∣A​是
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz连续,则称
 
  
   
    
     h
    
   
   
    h
   
  
 h是局部
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz连续。直观来看,
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz常数描述的是输出关于输入变化速率的一个上界。对于一个小的
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     t
    
    
     z
    
   
   
    \mathrm{Lipschitz}
   
  
 Lipschitz参数,在邻域
 
  
   
    
     A
    
   
   
    \mathcal{A}
   
  
 A中给定任意两个点,它们输出的改变被限制在一个小的范围里。

 根据微分中值定理,给定一个最小值点

      θ
     
     
      i
     
    
   
   
    \theta_i
   
  
 θi​,对于任意点
 
  
   
    
     ∀
    
    
     
      θ
     
     
      i
     
     
      ′
     
    
    
     ∈
    
    
     A
    
   
   
    \forall \theta_i^{\prime}\in \mathcal{A}
   
  
 ∀θi′​∈A,则有如下公式成立 
  
   
    
     
      ∥
     
     
      ∣
     
     
      L
     
     
      (
     
     
      
       θ
      
      
       i
      
      
       ′
      
     
     
      )
     
     
      −
     
     
      L
     
     
      (
     
     
      
       θ
      
      
       i
      
     
     
      )
     
     
      
       ∥
      
      
       2
      
     
     
      =
     
     
      ∥
     
     
      ∇
     
     
      L
     
     
      (
     
     
      ζ
     
     
      )
     
     
      (
     
     
      
       θ
      
      
       i
      
      
       ′
      
     
     
      −
     
     
      
       θ
      
      
       i
      
     
     
      )
     
     
      
       ∥
      
      
       2
      
     
    
    
     \||L(\theta_i^{\prime})-L(\theta_i)\|_2 = \|\nabla L (\zeta) (\theta_i^{\prime}-\theta_i)\|_2
    
   
  ∥∣L(θi′​)−L(θi​)∥2​=∥∇L(ζ)(θi′​−θi​)∥2​其中
 
  
   
    
     ζ
    
    
     =
    
    
     c
    
    
     
      θ
     
     
      i
     
    
    
     +
    
    
     (
    
    
     1
    
    
     −
    
    
     c
    
    
     )
    
    
     
      θ
     
     
      i
     
     
      ′
     
    
    
     ,
    
    
     c
    
    
     ∈
    
    
     [
    
    
     0
    
    
     ,
    
    
     1
    
    
     ]
    
   
   
    \zeta=c \theta_i + (1-c)\theta^\prime_i, c \in [0,1]
   
  
 ζ=cθi​+(1−c)θi′​,c∈[0,1],根据
 
  
   
    
     C
    
    
     a
    
    
     u
    
    
     c
    
    
     h
    
    
     y
    
    
     -
    
    
     S
    
    
     c
    
    
     h
    
    
     w
    
    
     a
    
    
     r
    
    
     z
    
   
   
    \mathrm{Cauchy\text{-}Schwarz}
   
  
 Cauchy-Schwarz不等式可知
  
   
    
     
      ∥
     
     
      ∣
     
     
      L
     
     
      (
     
     
      
       θ
      
      
       i
      
      
       ′
      
     
     
      )
     
     
      −
     
     
      L
     
     
      (
     
     
      
       θ
      
      
       i
      
     
     
      )
     
     
      
       ∥
      
      
       2
      
     
     
      ≤
     
     
      ∥
     
     
      ∇
     
     
      L
     
     
      (
     
     
      ζ
     
     
      )
     
     
      
       ∥
      
      
       2
      
     
     
      ∥
     
     
      (
     
     
      
       θ
      
      
       i
      
      
       ′
      
     
     
      −
     
     
      
       θ
      
      
       i
      
     
     
      )
     
     
      
       ∥
      
      
       2
      
     
    
    
     \||L(\theta_i^{\prime})-L(\theta_i)\|_2 \le \|\nabla L (\zeta)\|_2 \|(\theta_i^{\prime}-\theta_i)\|_2
    
   
  ∥∣L(θi′​)−L(θi​)∥2​≤∥∇L(ζ)∥2​∥(θi′​−θi​)∥2​当
 
  
   
    
     
      θ
     
     
      i
     
     
      ′
     
    
    
     →
    
    
     θ
    
   
   
    \theta_i^{\prime}\rightarrow \theta
   
  
 θi′​→θ时,相应的
 
  
   
    
     L
    
    
     i
    
    
     p
    
    
     s
    
    
     c
    
    
     h
    
    
     i
    
    
     z
    
   
   
    \mathrm{Lipschiz}
   
  
 Lipschiz常数接近
 
  
   
    
     ∥
    
    
     ∇
    
    
     L
    
    
     (
    
    
     
      θ
     
     
      i
     
    
    
     )
    
    
     
      ∥
     
     
      2
     
    
   
   
    \|\nabla L(\theta_i)\|_2
   
  
 ∥∇L(θi​)∥2​。因此可以通过减小
 
  
   
    
     ∥
    
    
     ∇
    
    
     L
    
    
     (
    
    
     
      θ
     
     
      i
     
    
    
     )
    
    
     ∥
    
   
   
    \|\nabla L(\theta_i)\|
   
  
 ∥∇L(θi​)∥的数值使得模型能够更平滑的收敛。

3 论文方法

对带有梯度范数约束的损失函数求梯度可得

       ∇
      
      
       θ
      
     
     
      L
     
     
      (
     
     
      θ
     
     
      )
     
     
      =
     
     
      
       ∇
      
      
       θ
      
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      )
     
     
      +
     
     
      
       ∇
      
      
       θ
      
     
     
      (
     
     
      λ
     
     
      ⋅
     
     
      ∥
     
     
      
       ∇
      
      
       θ
      
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      )
     
     
      
       ∥
      
      
       p
      
     
     
      )
     
    
    
     \nabla_\theta L(\theta)=\nabla_\theta L_{\mathcal{S}}(\theta)+\nabla_\theta(\lambda \cdot \|\nabla_\theta L_{\mathcal{S}}(\theta)\|_p)
    
   
  ∇θ​L(θ)=∇θ​LS​(θ)+∇θ​(λ⋅∥∇θ​LS​(θ)∥p​)在本文中,作者令
 
  
   
    
     p
    
    
     =
    
    
     2
    
   
   
    p=2
   
  
 p=2,此时则有如下推导过程
  
   
    
     
      
       
        
         
          
           ∇
          
          
           θ
          
         
         
          ∥
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          
           ∥
          
          
           2
          
         
        
       
      
      
       
        
         
         
          =
         
         
          
           ∇
          
          
           θ
          
         
         
          [
         
         
          
           ∇
          
          
           θ
          
          
           ⊤
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          ⋅
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          
           ]
          
          
           
            1
           
           
            2
           
          
         
        
       
      
     
     
      
       
        
       
      
      
       
        
         
         
          =
         
         
          
           1
          
          
           2
          
         
         
          ⋅
         
         
          
           ∇
          
          
           θ
          
          
           2
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          
           
            
             ∇
            
            
             θ
            
           
           
            
             L
            
            
             S
            
           
           
            (
           
           
            θ
           
           
            )
           
          
          
           
            ∥
           
           
            
             ∇
            
            
             θ
            
           
           
            
             L
            
            
             S
            
           
           
            (
           
           
            θ
           
           
            )
           
           
            
             ∥
            
            
             2
            
           
          
         
        
       
      
     
    
    
     \begin{aligned}\nabla_\theta \|\nabla_\theta L_\mathcal{S}(\theta)\|_2&=\nabla_\theta[\nabla^{\top}_\theta L_{\mathcal{S}}(\theta)\cdot \nabla_\theta L_\mathcal{S}(\theta)]^{\frac{1}{2}}\\&=\frac{1}{2}\cdot \nabla^2_\theta L_{\mathcal{S}}(\theta)\frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2}\end{aligned}
    
   
  ∇θ​∥∇θ​LS​(θ)∥2​​=∇θ​[∇θ⊤​LS​(θ)⋅∇θ​LS​(θ)]21​=21​⋅∇θ2​LS​(θ)∥∇θ​LS​(θ)∥2​∇θ​LS​(θ)​​将该结果带入到梯度范数约束的损失函数中,则有以下公式

  
   
    
     
      
       ∇
      
      
       θ
      
     
     
      L
     
     
      (
     
     
      θ
     
     
      )
     
     
      =
     
     
      
       ∇
      
      
       θ
      
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      )
     
     
      +
     
     
      λ
     
     
      ⋅
     
     
      
       ∇
      
      
       θ
      
      
       2
      
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      )
     
     
      
       
        
         ∇
        
        
         θ
        
       
       
        
         L
        
        
         S
        
       
       
        (
       
       
        θ
       
       
        )
       
      
      
       
        ∥
       
       
        
         ∇
        
        
         θ
        
       
       
        
         L
        
        
         S
        
       
       
        (
       
       
        θ
       
       
        )
       
       
        
         ∥
        
        
         2
        
       
      
     
    
    
     \nabla_\theta L(\theta)=\nabla_\theta L_{\mathcal{S}}(\theta)+\lambda \cdot \nabla^2_\theta L_{\mathcal{S}}(\theta) \frac{\nabla_\theta L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2}
    
   
  ∇θ​L(θ)=∇θ​LS​(θ)+λ⋅∇θ2​LS​(θ)∥∇θ​LS​(θ)∥2​∇θ​LS​(θ)​可以发现,以上公式中涉及到
 
  
   
    
     H
    
    
     e
    
    
     s
    
    
     s
    
    
     i
    
    
     a
    
    
     n
    
   
   
    \mathrm{Hessian}
   
  
 Hessian矩阵的计算,在深度学习中,计算参数的
 
  
   
    
     H
    
    
     e
    
    
     s
    
    
     s
    
    
     i
    
    
     a
    
    
     n
    
   
   
    \mathrm{Hessian}
   
  
 Hessian矩阵会带来高昂的计算成本,所以需要用到一些近似的方法。作者将损失函数进行泰勒展开,其中令
 
  
   
    
     H
    
    
     =
    
    
     
      ∇
     
     
      θ
     
     
      2
     
    
    
     
      L
     
     
      S
     
    
    
     (
    
    
     θ
    
    
     )
    
   
   
    H=\nabla^2_\theta L_\mathcal{S}(\theta)
   
  
 H=∇θ2​LS​(θ),则有
  
   
    
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      +
     
     
      Δ
     
     
      θ
     
     
      )
     
     
      =
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      )
     
     
      +
     
     
      
       ∇
      
      
       θ
      
      
       ⊤
      
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      )
     
     
      ⋅
     
     
      Δ
     
     
      θ
     
     
      +
     
     
      
       1
      
      
       2
      
     
     
      Δ
     
     
      
       θ
      
      
       ⊤
      
     
     
      H
     
     
      Δ
     
     
      θ
     
     
      +
     
     
      O
     
     
      (
     
     
      ∥
     
     
      Δ
     
     
      θ
     
     
      
       ∥
      
      
       2
      
      
       2
      
     
     
      )
     
    
    
     L_\mathcal{S}(\theta+\Delta \theta)=L_\mathcal{S}(\theta)+\nabla^{\top}_{\theta}L_\mathcal{S}(\theta)\cdot \Delta \theta + \frac{1}{2} \Delta \theta^{\top} H \Delta \theta +\mathcal{O}(\|\Delta \theta\|_2^2)
    
   
  LS​(θ+Δθ)=LS​(θ)+∇θ⊤​LS​(θ)⋅Δθ+21​Δθ⊤HΔθ+O(∥Δθ∥22​)进而则有
  
   
    
     
      
       
        
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          +
         
         
          Δ
         
         
          θ
         
         
          )
         
        
       
      
      
       
        
         
         
          =
         
         
          
           ∇
          
          
           
            Δ
           
           
            θ
           
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          +
         
         
          Δ
         
         
          θ
         
         
          )
         
         
          =
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          +
         
         
          H
         
         
          Δ
         
         
          θ
         
         
          +
         
         
          O
         
         
          (
         
         
          ∥
         
         
          Δ
         
         
          θ
         
         
          
           ∥
          
          
           2
          
          
           2
          
         
         
          )
         
        
       
      
     
    
    
     \begin{aligned}\nabla_\theta L_\mathcal{S}(\theta+\Delta \theta)&=\nabla_{\Delta\theta} L_\mathcal{S} (\theta + \Delta\theta)=\nabla_\theta L_{\mathcal{S}}(\theta)+ H \Delta \theta + \mathcal{O}(\|\Delta \theta\|^2_2)\end{aligned}
    
   
  ∇θ​LS​(θ+Δθ)​=∇Δθ​LS​(θ+Δθ)=∇θ​LS​(θ)+HΔθ+O(∥Δθ∥22​)​其中令
 
  
   
    
     Δ
    
    
     θ
    
    
     =
    
    
     r
    
    
     v
    
   
   
    \Delta \theta=r v
   
  
 Δθ=rv,
 
  
   
    
     r
    
   
   
    r
   
  
 r表示一个小的数值,
 
  
   
    
     v
    
   
   
    v
   
  
 v表示一个向量,带入上式则有
  
   
    
     
      H
     
     
      v
     
     
      =
     
     
      
       
        
         ∇
        
        
         θ
        
       
       
        
         L
        
        
         S
        
       
       
        (
       
       
        θ
       
       
        +
       
       
        r
       
       
        v
       
       
        )
       
       
        −
       
       
        
         ∇
        
        
         θ
        
       
       
        
         L
        
        
         S
        
       
       
        (
       
       
        θ
       
       
        )
       
      
      
       r
      
     
     
      +
     
     
      O
     
     
      (
     
     
      r
     
     
      )
     
    
    
     H v =\frac{\nabla_\theta L_{\mathcal{S}}(\theta + r v)-\nabla_\theta L_{\mathcal{S}}(\theta)}{r}+\mathcal{O}(r)
    
   
  Hv=r∇θ​LS​(θ+rv)−∇θ​LS​(θ)​+O(r)如果令
 
  
   
    
     v
    
    
     =
    
    
     
      
       
        ∇
       
       
        θ
       
      
      
       
        L
       
       
        S
       
      
      
       (
      
      
       θ
      
      
       )
      
     
     
      
       ∥
      
      
       
        ∇
       
       
        θ
       
      
      
       
        L
       
       
        S
       
      
      
       (
      
      
       θ
      
      
       )
      
      
       ∥
      
     
    
   
   
    v=\frac{\nabla_{\theta}L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|}
   
  
 v=∥∇θ​LS​(θ)∥∇θ​LS​(θ)​,则有
  
   
    
     
      H
     
     
      
       
        
         ∇
        
        
         θ
        
       
       
        
         L
        
        
         S
        
       
       
        (
       
       
        θ
       
       
        )
       
      
      
       
        ∥
       
       
        
         ∇
        
        
         θ
        
       
       
        
         L
        
        
         S
        
       
       
        (
       
       
        θ
       
       
        )
       
       
        
         ∥
        
        
         2
        
       
      
     
     
      ≈
     
     
      
       
        
         ∇
        
        
         θ
        
       
       
        L
       
       
        (
       
       
        θ
       
       
        +
       
       
        r
       
       
        
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
        
        
         
          ∥
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          
           ∥
          
          
           2
          
         
        
       
       
        )
       
       
        −
       
       
        
         ∇
        
        
         θ
        
       
       
        L
       
       
        (
       
       
        θ
       
       
        )
       
      
      
       r
      
     
    
    
     H \frac{\nabla_{\theta}L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2}\approx \frac{\nabla_\theta L(\theta + r\frac{\nabla_\theta L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2})-\nabla_\theta L(\theta)}{r}
    
   
  H∥∇θ​LS​(θ)∥2​∇θ​LS​(θ)​≈r∇θ​L(θ+r∥∇θ​LS​(θ)∥2​∇θ​LS​(θ)​)−∇θ​L(θ)​

综上所述,经过整理可得

           ∇
          
          
           θ
          
         
         
          L
         
         
          (
         
         
          θ
         
         
          )
         
        
       
      
      
       
        
         
         
          =
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          +
         
         
          
           λ
          
          
           r
          
         
         
          ⋅
         
         
          (
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          +
         
         
          r
         
         
          
           
            
             ∇
            
            
             θ
            
           
           
            
             L
            
            
             S
            
           
           
            (
           
           
            θ
           
           
            )
           
          
          
           
            ∥
           
           
            
             ∇
            
            
             θ
            
           
           
            
             L
            
            
             S
            
           
           
            (
           
           
            θ
           
           
            )
           
           
            
             ∥
            
            
             2
            
           
          
         
         
          )
         
         
          −
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          )
         
        
       
      
     
     
      
       
        
       
      
      
       
        
         
         
          =
         
         
          (
         
         
          1
         
         
          −
         
         
          α
         
         
          )
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          +
         
         
          α
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          +
         
         
          r
         
         
          
           
            
             ∇
            
            
             θ
            
           
           
            
             L
            
            
             S
            
           
           
            (
           
           
            θ
           
           
            )
           
          
          
           
            ∥
           
           
            
             ∇
            
            
             θ
            
           
           
            
             L
            
            
             S
            
           
           
            (
           
           
            θ
           
           
            )
           
           
            
             ∥
            
            
             2
            
           
          
         
         
          )
         
        
       
      
     
    
    
     \begin{aligned}\nabla_\theta L(\theta)&=\nabla_\theta L_\mathcal{S} (\theta)+\frac{\lambda}{r}\cdot (\nabla_\theta L_{\mathcal{S}}(\theta + r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2})-\nabla_\theta L_\mathcal{S}(\theta))\\&=(1-\alpha)\nabla_\theta L_\mathcal{S} (\theta)+\alpha \nabla_\theta L_\mathcal{S}(\theta+r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2})\end{aligned}
    
   
  ∇θ​L(θ)​=∇θ​LS​(θ)+rλ​⋅(∇θ​LS​(θ+r∥∇θ​LS​(θ)∥2​∇θ​LS​(θ)​)−∇θ​LS​(θ))=(1−α)∇θ​LS​(θ)+α∇θ​LS​(θ+r∥∇θ​LS​(θ)∥2​∇θ​LS​(θ)​)​其中
 
  
   
    
     α
    
    
     =
    
    
     
      λ
     
     
      r
     
    
   
   
    \alpha=\frac{\lambda}{r}
   
  
 α=rλ​,称
 
  
   
    
     α
    
   
   
    \alpha
   
  
 α为平衡系数,取值范围为
 
  
   
    
     0
    
    
     ≤
    
    
     α
    
    
     ≤
    
    
     1
    
   
   
    0 \le \alpha \le 1
   
  
 0≤α≤1。作者为了避免在近似计算梯度时,以上公式中的第二项链式法则求梯度需要计算
 
  
   
    
     H
    
    
     e
    
    
     s
    
    
     s
    
    
     i
    
    
     a
    
    
     n
    
   
   
    \mathrm{Hessian}
   
  
 Hessian矩阵,做了以下的近似则有
  
   
    
     
      
       ∇
      
      
       θ
      
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      +
     
     
      r
     
     
      
       
        
         ∇
        
        
         θ
        
       
       
        
         L
        
        
         S
        
       
       
        (
       
       
        θ
       
       
        )
       
      
      
       
        ∥
       
       
        
         ∇
        
        
         θ
        
       
       
        
         L
        
        
         S
        
       
       
        (
       
       
        θ
       
       
        )
       
       
        
         ∥
        
        
         2
        
       
      
     
     
      )
     
     
      ≈
     
     
      
       ∇
      
      
       θ
      
     
     
      
       L
      
      
       S
      
     
     
      (
     
     
      θ
     
     
      )
     
     
      
       ∣
      
      
       
        θ
       
       
        =
       
       
        θ
       
       
        +
       
       
        r
       
       
        
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
        
        
         
          ∥
         
         
          
           ∇
          
          
           θ
          
         
         
          
           L
          
          
           S
          
         
         
          (
         
         
          θ
         
         
          )
         
         
          
           ∥
          
          
           2
          
         
        
       
      
     
    
    
     \nabla_\theta L_\mathcal{S}(\theta+r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2})\approx \nabla_\theta L_\mathcal{S} (\theta)|_{\theta =\theta +r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2}}
    
   
  ∇θ​LS​(θ+r∥∇θ​LS​(θ)∥2​∇θ​LS​(θ)​)≈∇θ​LS​(θ)∣θ=θ+r∥∇θ​LS​(θ)∥2​∇θ​LS​(θ)​​以下算法流程图对本论文的训练方法进行汇总

4 实验结果

下表表示的是在

     C
    
    
     i
    
    
     f
    
    
     a
    
    
     r
    
    
     10
    
   
   
    \mathrm{Cifar10}
   
  
 Cifar10和
 
  
   
    
     C
    
    
     i
    
    
     f
    
    
     a
    
    
     r
    
    
     100
    
   
   
    \mathrm{Cifar100}
   
  
 Cifar100这两个数据集中不同
 
  
   
    
     C
    
    
     N
    
    
     N
    
   
   
    \mathrm{CNN}
   
  
 CNN网络结构在标准训练,
 
  
   
    
     S
    
    
     A
    
    
     M
    
   
   
    \mathrm{SAM}
   
  
 SAM和本文的梯度约束这三种训练方法之间的测试错误率的比较。可以很直观的发现,本文提出的方法在绝大多数情况下测试错误率都是最低的,这也从侧面验证了经过论文方法的训练可以提高
 
  
   
    
     C
    
    
     N
    
    
     N
    
   
   
    \mathrm{CNN}
   
  
 CNN模型的泛化性。


论文作者也在当前非常热门的网络结构

      V
     
     
      i
     
     
      s
     
     
      i
     
     
      o
     
     
      n
     
     
       
     
     
      T
     
     
      r
     
     
      a
     
     
      n
     
     
      s
     
     
      f
     
     
      o
     
     
      r
     
     
      m
     
     
      e
     
     
      r
     
    
    
     \mathrm{Vision \text{ } Transformer}
    
   
  Vision Transformer进行了实验。下表表示的是在
  
   
    
     
      C
     
     
      i
     
     
      f
     
     
      a
     
     
      r
     
     
      10
     
    
    
     \mathrm{Cifar10}
    
   
  Cifar10和
  
   
    
     
      C
     
     
      i
     
     
      f
     
     
      a
     
     
      r
     
     
      100
     
    
    
     \mathrm{Cifar100}
    
   
  Cifar100这两个数据集中不同
  
   
    
     
      V
     
     
      i
     
     
      T
     
    
    
     \mathrm{ViT}
    
   
  ViT网络结构在标准训练,
  
   
    
     
      S
     
     
      A
     
     
      M
     
    
    
     \mathrm{SAM}
    
   
  SAM和本文的梯度约束这三种训练方法之间的测试错误率的比较。同理也可以发现本文提出的方法在所有情况下测试错误率都是最低的,这说明本文的方法也可以提到
  
   
    
     
      V
     
     
      i
     
     
      s
     
     
      i
     
     
      o
     
     
      n
     
     
       
     
     
      t
     
     
      r
     
     
      a
     
     
      n
     
     
      s
     
     
      f
     
     
      o
     
     
      r
     
     
      m
     
     
      e
     
     
      r
     
    
    
     \mathrm{Vision \text{ } transformer}
    
   
  Vision transformer模型的泛化性。


本文转载自: https://blog.csdn.net/qq_38406029/article/details/122851202
版权归原作者 鬼道2022 所有, 如有侵权,请联系我们删除。

“清华大学出品:罚梯度范数提高深度学习模型泛化性”的评论:

还没有评论