0


Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)

目录

1 计算图原理

**计算图(Computational Graph)**是机器学习领域中推导神经网络和其他模型算法,以及软件编程实现的有效工具。

计算图的核心是将模型表示成一张**拓扑有序(Topologically Ordered)有向无环图(Directed Acyclic Graph)**,其中每个节点

     u
    
    
     i
    
   
  
  
   u_i
  
 
ui​包含数值信息(可以是标量、向量、矩阵或张量)和算子信息

 
  
   
    
     f
    
    
     i
    
   
  
  
   f_i
  
 
fi​。拓扑有序指当前节点仅在全体指向它的节点被计算后才进行计算。

在这里插入图片描述
计算图的优点在于:

  • 可以通过基本初等映射 的拓扑联结,形成复合的复杂模型,大多数神经网络模型都可以被计算图表示;
  • 便于实现**自动微分机(Automatic Differentiation Machine)**,对给定计算图可基于链式法则由节点局部梯度进行反向传播。

计算图的基本概念如表所示,基于计算图的基本前向传播和反向传播算法如表
符号含义

       n
      
     
     
      n
     
    
   n计算图的节点数
   
    
     
      
       l
      
     
     
      l
     
    
   l计算图的叶节点数
   
    
     
      
       L
      
     
     
      L
     
    
   L计算图的叶节点索引集
   
    
     
      
       C
      
     
     
      C
     
    
   C计算图的非叶节点索引集
   
    
     
      
       E
      
     
     
      E
     
    
   E计算图的有向边集合
   
    
     
      
       
        u
       
       
        i
       
      
     
     
      u_i
     
    
   ui​计算图中的第
   
    
     
      
       i
      
     
     
      i
     
    
   i节点或其值
   
    
     
      
       
        d
       
       
        i
       
      
     
     
      d_i
     
    
   di​
   
    
     
      
       
        u
       
       
        i
       
      
     
     
      u_i
     
    
   ui​ 的维度
   
    
     
      
       
        f
       
       
        i
       
      
     
     
      f_i
     
    
   fi​
   
    
     
      
       
        u
       
       
        i
       
      
     
     
      u_i
     
    
   ui​的算子
   
    
     
      
       
        α
       
       
        i
       
      
     
     
      \alpha _i
     
    
   αi​
   
    
     
      
       
        u
       
       
        i
       
      
     
     
      u_i
     
    
   ui​的全体关联输入
   
    
     
      
       
        J
       
       
        
         j
        
        
         →
        
        
         i
        
       
      
     
     
      \boldsymbol{J}_{j\rightarrow i}
     
    
   Jj→i​节点
   
    
     
      
       
        u
       
       
        i
       
      
     
     
      u_i
     
    
   ui​关于节点
   
    
     
      
       
        u
       
       
        j
       
      
     
     
      u_j
     
    
   uj​的雅克比矩阵
   
    
     
      
       
        P
       
       
        i
       
      
     
     
      \boldsymbol{P}_i
     
    
   Pi​输出节点关于输入节点的雅克比矩阵

2 基于计算图的传播

基于计算图的前向传播算法如下

在这里插入图片描述
基于计算图的反向传播算法如下

在这里插入图片描述

以第一节的图为例,可知

    E
   
   
    =
   
   
    
     {
    
    
     
      (
     
     
      1
     
     
      ,
     
     
      3
     
     
      )
     
    
    
     ,
    
    
     
      (
     
     
      2
     
     
      ,
     
     
      3
     
     
      )
     
    
    
     ,
    
    
     
      (
     
     
      2
     
     
      ,
     
     
      4
     
     
      )
     
    
    
     ,
    
    
     
      (
     
     
      3
     
     
      ,
     
     
      4
     
     
      )
     
    
    
     }
    
   
  
  
   E=\left\{ \left( 1,3 \right) ,\left( 2,3 \right) ,\left( 2,4 \right) ,\left( 3,4 \right) \right\}
  
 
E={(1,3),(2,3),(2,4),(3,4)}。首先进行前向传播:

 
  
   
    
     {
    
    
     
      
       
        
         
          
           u
          
          
           3
          
         
         
          =
         
         
          
           u
          
          
           1
          
         
         
          +
         
         
          
           u
          
          
           2
          
         
         
          =
         
         
          5
         
        
       
      
     
     
      
       
        
         
          
           u
          
          
           4
          
         
         
          =
         
         
          
           u
          
          
           2
          
         
         
          
           u
          
          
           3
          
         
         
          =
         
         
          15
         
        
       
      
     
    
   
   
    \begin{cases} u_3=u_1+u_2=5\\ u_4=u_2u_3=15\\\end{cases}
   
  
 {u3​=u1​+u2​=5u4​=u2​u3​=15​

 
  
   
    
     {
    
    
     
      
       
        
         
          
           J
          
          
           
            1
           
           
            →
           
           
            3
           
          
         
         
          =
         
         
          
           
            ∂
           
           
            
             u
            
            
             3
            
           
          
          
           /
          
          
           
            ∂
           
           
            
             u
            
            
             1
            
           
           
            =
           
          
         
         
          1
         
        
       
      
     
     
      
       
        
         
          
           J
          
          
           
            2
           
           
            →
           
           
            3
           
          
         
         
          =
         
         
          
           
            ∂
           
           
            
             u
            
            
             3
            
           
          
          
           /
          
          
           
            ∂
           
           
            
             u
            
            
             2
            
           
           
            =
           
          
         
         
          1
         
        
       
      
     
     
      
       
        
         
          
           J
          
          
           
            2
           
           
            →
           
           
            4
           
          
         
         
          =
         
         
          
           
            ∂
           
           
            
             u
            
            
             4
            
           
          
          
           /
          
          
           
            ∂
           
           
            
             u
            
            
             2
            
           
           
            =
           
          
         
         
          
           u
          
          
           3
          
         
         
          =
         
         
          5
         
        
       
      
     
     
      
       
        
         
          
           J
          
          
           
            3
           
           
            →
           
           
            4
           
          
         
         
          =
         
         
          
           
            ∂
           
           
            
             u
            
            
             4
            
           
          
          
           /
          
          
           
            ∂
           
           
            
             u
            
            
             3
            
           
           
            =
           
          
         
         
          
           u
          
          
           2
          
         
         
          =
         
         
          3
         
        
       
      
     
    
   
   
    \begin{cases} \boldsymbol{J}_{1\rightarrow 3}={{\partial u_3}/{\partial u_1=}}1\\ \boldsymbol{J}_{2\rightarrow 3}={{\partial u_3}/{\partial u_2=}}1\\ \boldsymbol{J}_{2\rightarrow 4}={{\partial u_4}/{\partial u_2=}}u_3=5\\ \boldsymbol{J}_{3\rightarrow 4}={{\partial u_4}/{\partial u_3=}}u_2=3\\\end{cases}
   
  
 ⎩⎨⎧​J1→3​=∂u3​/∂u1​=1J2→3​=∂u3​/∂u2​=1J2→4​=∂u4​/∂u2​=u3​=5J3→4​=∂u4​/∂u3​=u2​=3​

接着进行反向传播:

     {
    
    
     
      
       
        
         
          
           P
          
          
           4
          
         
         
          =
         
         
          1
         
        
       
      
     
     
      
       
        
         
          
           P
          
          
           3
          
         
         
          =
         
         
          
           P
          
          
           4
          
         
         
          
           J
          
          
           
            3
           
           
            →
           
           
            4
           
          
         
         
          =
         
         
          3
         
        
       
      
     
     
      
       
        
         
          
           P
          
          
           2
          
         
         
          =
         
         
          
           P
          
          
           4
          
         
         
          
           J
          
          
           
            2
           
           
            →
           
           
            4
           
          
         
         
          +
         
         
          
           P
          
          
           3
          
         
         
          
           J
          
          
           
            2
           
           
            →
           
           
            3
           
          
         
         
          =
         
         
          8
         
        
       
      
     
     
      
       
        
         
          
           P
          
          
           1
          
         
         
          =
         
         
          
           P
          
          
           3
          
         
         
          
           J
          
          
           
            1
           
           
            →
           
           
            3
           
          
         
         
          =
         
         
          3
         
        
       
      
     
    
   
   
    \begin{cases} \boldsymbol{P}_4=1\\ \boldsymbol{P}_3=\boldsymbol{P}_4\boldsymbol{J}_{3\rightarrow 4}=3\\ \boldsymbol{P}_2=\boldsymbol{P}_4\boldsymbol{J}_{2\rightarrow 4}+\boldsymbol{P}_3\boldsymbol{J}_{2\rightarrow 3}=8\\ \boldsymbol{P}_1=\boldsymbol{P}_3\boldsymbol{J}_{1\rightarrow 3}=3\\\end{cases}
   
  
 ⎩⎨⎧​P4​=1P3​=P4​J3→4​=3P2​=P4​J2→4​+P3​J2→3​=8P1​=P3​J1→3​=3​

3 神经网络计算图

一个神经网络的计算图实例如下,所有参数都可以用之前的模型表示

在这里插入图片描述

     L
    
    
     
      {
     
     
      
       
        
         
          
           
            u
           
           
            1
           
          
          
           =
          
          
           
            W
           
           
            1
           
          
          
           ∈
          
          
           
            R
           
           
            
             
              n
             
             
              1
             
            
            
             ×
            
            
             
              n
             
             
              0
             
            
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            2
           
          
          
           =
          
          
           
            b
           
           
            1
           
          
          
           ∈
          
          
           
            R
           
           
            
             n
            
            
             1
            
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            3
           
          
          
           =
          
          
           x
          
          
           ∈
          
          
           
            R
           
           
            
             n
            
            
             0
            
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            4
           
          
          
           =
          
          
           
            W
           
           
            2
           
          
          
           ∈
          
          
           
            R
           
           
            
             
              n
             
             
              2
             
            
            
             ×
            
            
             
              n
             
             
              1
             
            
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            5
           
          
          
           =
          
          
           
            b
           
           
            2
           
          
          
           ∈
          
          
           
            R
           
           
            
             n
            
            
             2
            
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            6
           
          
          
           =
          
          
           y
          
          
           ∈
          
          
           
            R
           
           
            
             n
            
            
             2
            
           
          
         
        
       
      
     
      
    
     C
    
    
     
      {
     
     
      
       
        
         
          
           
            u
           
           
            7
           
          
          
           =
          
          
           
            z
           
           
            1
           
          
          
           ∈
          
          
           
            R
           
           
            
             n
            
            
             1
            
           
          
          
           =
          
          
           
            W
           
           
            1
           
          
          
           x
          
          
           +
          
          
           
            b
           
           
            1
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            8
           
          
          
           =
          
          
           
            a
           
           
            1
           
          
          
           ∈
          
          
           
            R
           
           
            
             n
            
            
             1
            
           
          
          
           =
          
          
           σ
          
          
           
            (
           
           
            
             z
            
            
             1
            
           
           
            )
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            9
           
          
          
           =
          
          
           
            z
           
           
            2
           
          
          
           ∈
          
          
           
            R
           
           
            
             n
            
            
             2
            
           
          
          
           =
          
          
           
            W
           
           
            2
           
          
          
           
            a
           
           
            1
           
          
          
           +
          
          
           
            b
           
           
            2
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            10
           
          
          
           =
          
          
           y
          
          
           ∈
          
          
           
            R
           
           
            
             n
            
            
             2
            
           
          
          
           =
          
          
           σ
          
          
           
            (
           
           
            
             z
            
            
             2
            
           
           
            )
           
          
         
        
       
      
      
       
        
         
          
           
            u
           
           
            11
           
          
          
           =
          
          
           E
          
          
           ∈
          
          
           R
          
          
           =
          
          
           
            1
           
           
            2
           
          
          
           
            
             (
            
            
             y
            
            
             −
            
            
             
              
               y
              
              
               ~
              
             
            
            
             )
            
           
           
            T
           
          
          
           
            (
           
           
            y
           
           
            −
           
           
            
             
              y
             
             
              ~
             
            
           
           
            )
           
          
         
        
       
      
     
    
   
   
    L\begin{cases} u_1=\boldsymbol{W}^1\in \mathbb{R} ^{n_1\times n_0}\\ u_2=\boldsymbol{b}^1\in \mathbb{R} ^{n_1}\\ u_3=\boldsymbol{x}\in \mathbb{R} ^{n_0}\\ u_4=\boldsymbol{W}^2\in \mathbb{R} ^{n_2\times n_1}\\ u_5=\boldsymbol{b}^2\in \mathbb{R} ^{n_2}\\ u_6=\boldsymbol{y}\in \mathbb{R} ^{n_2}\\\end{cases}\,\, C\begin{cases} u_7=\boldsymbol{z}^1\in \mathbb{R} ^{n_1}=\boldsymbol{W}^1\boldsymbol{x}+\boldsymbol{b}^1\\ u_8=\boldsymbol{a}^1\in \mathbb{R} ^{n_1}=\sigma \left( \boldsymbol{z}^1 \right)\\ u_9=\boldsymbol{z}^2\in \mathbb{R} ^{n_2}=\boldsymbol{W}^2\boldsymbol{a}^1+\boldsymbol{b}^2\\ u_{10}=\boldsymbol{y}\in \mathbb{R} ^{n_2}=\sigma \left( \boldsymbol{z}^2 \right)\\ u_{11}=E\in \mathbb{R} =\frac{1}{2}\left( \boldsymbol{y}-\boldsymbol{\tilde{y}} \right) ^T\left( \boldsymbol{y}-\boldsymbol{\tilde{y}} \right)\\\end{cases}
   
  
 L⎩⎨⎧​u1​=W1∈Rn1​×n0​u2​=b1∈Rn1​u3​=x∈Rn0​u4​=W2∈Rn2​×n1​u5​=b2∈Rn2​u6​=y∈Rn2​​C⎩⎨⎧​u7​=z1∈Rn1​=W1x+b1u8​=a1∈Rn1​=σ(z1)u9​=z2∈Rn2​=W2a1+b2u10​=y∈Rn2​=σ(z2)u11​=E∈R=21​(y−y~​)T(y−y~​)​

4 自动微分机

自动微分机的基本原理是:

  • 跟踪记录从输入张量到输出张量的计算过程,并生成一幅前向传播计算图,计算图中的节点与张量一一对应
  • 基于计算图反向传播原理即可链式地求解输出节点关于各节点的梯度

必须指出,Pytorch不允许张量对张量求导,故输出节点必须是标量,通常为损失函数或输出向量的加权和;为节约内存,每次反向传播后Pytorch会自动释放前向传播计算图,即销毁中间计算节点的梯度和节点间的连接结构。

5 Pytorch中的自动微分

Tensor在自动微分机中的重要属性如表所示。
属性含义

device

该节点运行的设备环境,即CPU/GPU

requires_grad

自动微分机是否需要对该节点求导,缺省为False

grad

输出节点对该节点的梯度,缺省为None

grad_fn

中间计算节点关于全体输入节点的映射,记录了前向传播经过的操作。叶节点为None

is_leaf

该节点是否为叶节点
完成前向传播后,调用反向传播API即可更新各节点梯度,具体如下

backward(gradient=None, retain_graph=None, create_graph=None)

其中

  • gradient是权重向量,当输出节点 y y y不为标量时需指定与其同维的gradient,并以标量 g r a d i e n t T y gradient^Ty gradientTy为输出进行反向传播
  • retain_graph用于缓存前向传播计算图,可应用于一次传播测试多个损失函数等情形;
  • creat_graph用于构造导数计算图,可用于进一步求解高阶导数。

5.1 梯度缓存

中间计算节点的梯度需要通过

retain_grad()

方法进行缓存

w1 = torch.tensor([[2.],[3.]], requires_grad=True)
b1 = torch.tensor([1.], requires_grad=True)
x = torch.tensor([[10.],[20.]])

y = torch.mm(w1.transpose(0,1), x)+ b1
y.retain_grad()# 若不缓存则y.grad=None
out =3*y
out.backward()>> tensor([[30.],[60.]]) tensor([3.])None tensor([[3.]])

5.2 参数冻结

若希望冻结网络部分参数,只调整优化另一部分参数;或按顺序训练分支网络而屏蔽对主网络梯度的,可使用

detach()

方法从计算图中分离节点,阻断反向传播。分离的节点与原节点共享值内存,但不具有

grad

grad_fn

属性。

# 记第一层网络w1-b1为f,第二层网络w2-b2为g
w1 = torch.tensor([[2.],[3.]], requires_grad=True)
w2 = torch.tensor([3.], requires_grad=True)
b1 = torch.tensor([1.], requires_grad=True)
b2 = torch.tensor([2.], requires_grad=True)
x = torch.tensor([[10.],[20.]])

y = torch.mm(w1.transpose(0,1), x)+ b1
y_ = y.detach()
z = w2 * y_ + b2
out =3*z
out.backward()print(w1.grad, b1.grad, w2.grad, b2.grad)>>NoneNone tensor([243.]) tensor([3.])# f被冻结,梯度不更新# 若不使用detach冻结y之前的网络,则>> tensor([[90.],[180.]]) tensor([9.]) tensor([243.]) tensor([3.])

🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇


本文转载自: https://blog.csdn.net/FRIGIDWINTER/article/details/129290506
版权归原作者 Mr.Winter` 所有, 如有侵权,请联系我们删除。

“Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)”的评论:

还没有评论