0


Vision Transformer详解(附代码)

1 引言

     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer}
   
  
 Transformer在
 
  
   
    
     N
    
    
     L
    
    
     P
    
   
   
    \mathrm{NLP}
   
  
 NLP中大获成功,
  
   
    
     
      V
     
     
      i
     
     
      s
     
     
      i
     
     
      o
     
     
      n
     
     
       
     
     
      T
     
     
      r
     
     
      a
     
     
      n
     
     
      s
     
     
      f
     
     
      o
     
     
      r
     
     
      m
     
     
      e
     
     
      r
     
    
    
     \mathrm{Vision\text{ }Transformer}
    
   
  Vision Transformer则将
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer}
   
  
 Transformer模型架构扩展到计算机视觉的领域中,并且它可以很好的地取代卷积操作,在不依赖卷积的情况下,依然可以在图像分类任务上达到很好的效果。卷积操作只能考虑到局部的特征信息,而
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer}
   
  
 Transformer中的注意力机制可以综合考量全局的特征信息。
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer尽力做到在不改变
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer}
   
  
 Transformer中
 
  
   
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Encoder}
   
  
 Encoder架构的前提下,直接将其从
 
  
   
    
     N
    
    
     L
    
    
     P
    
   
   
    \mathrm{NLP}
   
  
 NLP领域迁移到计算机视觉领域中,目的是让原始的
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer}
   
  
 Transformer模型开箱即用。如果想要了解
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer}
   
  
 Transformer原理详细的介绍可以看我的上一篇文章《Transformer详解(附代码)》。

2 注意力机制应用

在正式详细介绍

     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer之前,先介绍两个注意力机制在计算机视觉中应用的例子。
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer并不是第一个将注意力机制应用到计算机视觉的领域中去的,其中
  
   
    
     
      S
     
     
      A
     
     
      G
     
     
      A
     
     
      N
     
    
    
     \mathrm{SAGAN}
    
   
  SAGAN和
  
   
    
     
      A
     
     
      t
     
     
      t
     
     
      n
     
     
      G
     
     
      A
     
     
      N
     
    
    
     \mathrm{AttnGAN}
    
   
  AttnGAN就早已经在
 
  
   
    
     G
    
    
     A
    
    
     N
    
   
   
    \mathrm{GAN}
   
  
 GAN的框架中引入了注意力机制,并且它们大大提高了图像生成的质量。

2.1 Self-Attention GAN

     S
    
    
     A
    
    
     G
    
    
     A
    
    
     N
    
   
   
    \mathrm{SAGAN}
   
  
 SAGAN在
 
  
   
    
     G
    
    
     A
    
    
     N
    
   
   
    \mathrm{GAN}
   
  
 GAN的框架中利用自注意力机制来捕获图像特征的长距离依赖关系,使得合成的图像中考量了所有的图像特征信息。
 
  
   
    
     S
    
    
     A
    
    
     G
    
    
     A
    
    
     N
    
   
   
    \mathrm{SAGAN}
   
  
 SAGAN中自注意力机制的操作原理如下图所示。

给定一个

      3
     
    
    
     3
    
   
  3通道的输入特征图
  
   
    
     
      X
     
     
      =
     
     
      (
     
     
      
       X
      
      
       1
      
     
     
      ,
     
     
      
       X
      
      
       2
      
     
     
      ,
     
     
      
       X
      
      
       3
      
     
     
      )
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     X=(X^1,X^2,X^3)\in \mathbb{R}^{3\times 3\times 3}
    
   
  X=(X1,X2,X3)∈R3×3×3,其中
  
   
    
     
      
       X
      
      
       i
      
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     X^{i}\in \mathbb{R}^{3\times 3}
    
   
  Xi∈R3×3,
  
   
    
     
      i
     
     
      ∈
     
     
      {
     
     
      1
     
     
      ,
     
     
      2
     
     
      ,
     
     
      3
     
     
      }
     
    
    
     i\in\{1,2,3\}
    
   
  i∈{1,2,3}。将
  
   
    
     
      X
     
    
    
     X
    
   
  X分别输入到三个不同的
  
   
    
     
      1
     
     
      ×
     
     
      1
     
    
    
     1\times 1
    
   
  1×1的卷积层中,并生成
  
   
    
     
      q
     
     
      u
     
     
      e
     
     
      r
     
     
      y
     
    
    
     \mathrm{query}
    
   
  query特征图
  
   
    
     
      Q
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     Q\in \mathbb{R}^{3\times 3\times 3}
    
   
  Q∈R3×3×3,
  
   
    
     
      k
     
     
      e
     
     
      y
     
    
    
     \mathrm{key}
    
   
  key特征图
  
   
    
     
      K
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     K\in \mathbb{R}^{3\times 3\times 3}
    
   
  K∈R3×3×3和
  
   
    
     
      v
     
     
      a
     
     
      l
     
     
      u
     
     
      e
     
    
    
     \mathrm{value}
    
   
  value特征图
  
   
    
     
      V
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     V\in \mathbb{R}^{3\times 3\times 3}
    
   
  V∈R3×3×3。生成
  
   
    
     
      Q
     
    
    
     Q
    
   
  Q具体的计算过程为,给定三个卷积核
  
   
    
     
      
       W
      
      
       
        q
       
       
        1
       
      
     
    
    
     W^{q1}
    
   
  Wq1,
  
   
    
     
      
       W
      
      
       
        q
       
       
        2
       
      
     
    
    
     W^{q2}
    
   
  Wq2和
  
   
    
     
      
       W
      
      
       
        q
       
       
        3
       
      
     
     
      ∈
     
     
      
       R
      
      
       
        1
       
       
        ×
       
       
        1
       
       
        ×
       
       
        3
       
      
     
    
    
     W^{q3}\in\mathbb{R}^{1\times1\times3}
    
   
  Wq3∈R1×1×3,并用这三个卷积核分别与
  
   
    
     
      X
     
    
    
     X
    
   
  X做卷积运算得到
  
   
    
     
      
       Q
      
      
       1
      
     
    
    
     Q^1
    
   
  Q1,
  
   
    
     
      
       Q
      
      
       2
      
     
    
    
     Q^2
    
   
  Q2和
  
   
    
     
      
       Q
      
      
       3
      
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     Q^3\in \mathbb{R}^{3 \times 3}
    
   
  Q3∈R3×3,即
   
    
     
      
       {
      
      
       
        
         
          
           
            Q
           
           
            1
           
          
         
        
        
         
          
           
           
            =
           
           
            X
           
           
            ∗
           
           
            
             W
            
            
             
              q
             
             
              1
             
            
           
          
         
        
       
       
        
         
          
           
            Q
           
           
            2
           
          
         
        
        
         
          
           
           
            =
           
           
            X
           
           
            ∗
           
           
            
             W
            
            
             
              q
             
             
              2
             
            
           
          
         
        
       
       
        
         
          
           
            Q
           
           
            3
           
          
         
        
        
         
          
           
           
            =
           
           
            X
           
           
            ∗
           
           
            
             W
            
            
             
              q
             
             
              3
             
            
           
          
         
        
       
      
     
     
      \left\{\begin{aligned}Q^1&=X * W^{q1}\\Q^2&=X * W^{q2}\\Q^3&=X*W^{q3}\end{aligned}\right.
     
    
   ⎩⎪⎨⎪⎧​Q1Q2Q3​=X∗Wq1=X∗Wq2=X∗Wq3​其中
  
   
    
     
      ∗
     
    
    
     *
    
   
  ∗表示卷积运算符号。同理生成
  
   
    
     
      K
     
    
    
     K
    
   
  K和
  
   
    
     
      V
     
    
    
     V
    
   
  V的计算过程与
  
   
    
     
      Q
     
    
    
     Q
    
   
  Q的计算过程类似。然后再利用
  
   
    
     
      Q
     
    
    
     Q
    
   
  Q和
  
   
    
     
      K
     
    
    
     K
    
   
  K进行注意力分数的计算得到矩阵
  
   
    
     
      A
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     A\in \mathbb{R}^{3 \times 3}
    
   
  A∈R3×3,其中矩阵
  
   
    
     
      A
     
    
    
     A
    
   
  A的元素
  
   
    
     
      
       a
      
      
       
        m
       
       
        l
       
      
     
    
    
     a_{ml}
    
   
  aml​的计算公式为
   
    
     
      
       
        a
       
       
        
         m
        
        
         l
        
       
      
      
       =
      
      
       
        Q
       
       
        m
       
      
      
       ∗
      
      
       
        K
       
       
        l
       
      
      
       ,
      
      
      
       m
      
      
       ∈
      
      
       {
      
      
       1
      
      
       ,
      
      
       2
      
      
       ,
      
      
       3
      
      
       }
      
      
       ,
      
      
       l
      
      
       ∈
      
      
       {
      
      
       1
      
      
       ,
      
      
       2
      
      
       ,
      
      
       3
      
      
       }
      
     
     
      a_{ml}=Q^m * K^l,\quad m \in \{1,2,3\},l\in \{1,2,3\}
     
    
   aml​=Qm∗Kl,m∈{1,2,3},l∈{1,2,3}再对矩阵
  
   
    
     
      A
     
    
    
     A
    
   
  A利用
  
   
    
     
      s
     
     
      o
     
     
      f
     
     
      t
     
     
      m
     
     
      a
     
     
      x
     
    
    
     \mathrm{softmax}
    
   
  softmax函数进行注意力分布的计算得到注意力分布矩阵
  
   
    
     
      S
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     S\in \mathbb{R}^{3\times 3}
    
   
  S∈R3×3,其中矩阵
  
   
    
     
      S
     
    
    
     S
    
   
  S的元素
  
   
    
     
      
       s
      
      
       
        m
       
       
        l
       
      
     
    
    
     s_{ml}
    
   
  sml​的计算公式为
   
    
     
      
       
        s
       
       
        
         m
        
        
         l
        
       
      
      
       =
      
      
       
        
         exp
        
        
         ⁡
        
        
         (
        
        
         
          a
         
         
          
           m
          
          
           l
          
         
        
        
         )
        
       
       
        
         
          ∑
         
         
          
           i
          
          
           =
          
          
           j
          
         
         
          3
         
        
        
         exp
        
        
         ⁡
        
        
         (
        
        
         
          a
         
         
          
           m
          
          
           j
          
         
        
        
         )
        
       
      
      
       ,
      
      
      
       m
      
      
       ∈
      
      
       {
      
      
       1
      
      
       ,
      
      
       2
      
      
       ,
      
      
       3
      
      
       }
      
      
       ,
      
      
       l
      
      
       ∈
      
      
       {
      
      
       1
      
      
       ,
      
      
       2
      
      
       ,
      
      
       3
      
      
       }
      
     
     
      s_{ml}=\frac{\exp(a_{ml})}{\sum\limits_{i=j}^{3}\exp(a_{mj})},\quad m \in \{1,2,3\},l\in\{1,2,3\}
     
    
   sml​=i=j∑3​exp(amj​)exp(aml​)​,m∈{1,2,3},l∈{1,2,3}最后利用注意力分布矩阵
  
   
    
     
      S
     
    
    
     S
    
   
  S和
  
   
    
     
      v
     
     
      a
     
     
      l
     
     
      u
     
     
      e
     
    
    
     \mathrm{value}
    
   
  value特征图
  
   
    
     
      V
     
    
    
     V
    
   
  V得到最后的输出
  
   
    
     
      O
     
     
      =
     
     
      (
     
     
      
       O
      
      
       1
      
     
     
      ,
     
     
      
       O
      
      
       2
      
     
     
      ,
     
     
      
       O
      
      
       3
      
     
     
      )
     
     
      ∈
     
     
      
       R
      
      
       
        3
       
       
        ×
       
       
        3
       
       
        ×
       
       
        3
       
      
     
    
    
     O=(O^1,O^2,O^3)\in \mathbb{R}^{3\times 3\times 3}
    
   
  O=(O1,O2,O3)∈R3×3×3,即
   
    
     
      
       {
      
      
       
        
         
          
           
            O
           
           
            1
           
          
         
        
        
         
          
           
           
            =
           
           
            
             s
            
            
             11
            
           
           
            ⋅
           
           
            
             V
            
            
             1
            
           
           
            +
           
           
            
             s
            
            
             12
            
           
           
            ⋅
           
           
            
             V
            
            
             2
            
           
           
            +
           
           
            
             s
            
            
             13
            
           
           
            ⋅
           
           
            
             V
            
            
             3
            
           
          
         
        
       
       
        
         
          
           
            O
           
           
            2
           
          
         
        
        
         
          
           
           
            =
           
           
            
             s
            
            
             21
            
           
           
            ⋅
           
           
            
             V
            
            
             1
            
           
           
            +
           
           
            
             s
            
            
             22
            
           
           
            ⋅
           
           
            
             V
            
            
             2
            
           
           
            +
           
           
            
             s
            
            
             23
            
           
           
            ⋅
           
           
            
             V
            
            
             3
            
           
          
         
        
       
       
        
         
          
           
            O
           
           
            3
           
          
         
        
        
         
          
           
           
            =
           
           
            
             s
            
            
             31
            
           
           
            ⋅
           
           
            
             V
            
            
             1
            
           
           
            +
           
           
            
             s
            
            
             32
            
           
           
            ⋅
           
           
            
             V
            
            
             2
            
           
           
            +
           
           
            
             s
            
            
             33
            
           
           
            ⋅
           
           
            
             V
            
            
             3
            
           
          
         
        
       
      
     
     
      \left\{\begin{aligned}O^1&=s_{11}\cdot V^1+s_{12}\cdot V^2+s_{13}\cdot V^3\\O^2&=s_{21}\cdot V^1+s_{22}\cdot V^2+s_{23}\cdot V^3\\O^3&=s_{31}\cdot V^1+s_{32}\cdot V^2+s_{33}\cdot V^3\end{aligned}\right.
     
    
   ⎩⎪⎨⎪⎧​O1O2O3​=s11​⋅V1+s12​⋅V2+s13​⋅V3=s21​⋅V1+s22​⋅V2+s23​⋅V3=s31​⋅V1+s32​⋅V2+s33​⋅V3​

2.2 AttnGAN

     A
    
    
     t
    
    
     t
    
    
     n
    
    
     G
    
    
     A
    
    
     N
    
   
   
    \mathrm{AttnGAN}
   
  
 AttnGAN通过利用注意力机制来实现多阶段细颗粒度的文本到图像的生成,它可以通过关注自然语言中的一些重要单词来对图像的不同子区域进行合成。比如通过文本“一只鸟有黄色的羽毛和黑色的眼睛”来生成图像时,会对关键词“鸟”,“羽毛”,“眼睛”,“黄色”,“黑色”给予不同的生成权重,并根据这些关键词的引导在图像的不同的子区域中进行细节的丰富。
 
  
   
    
     A
    
    
     t
    
    
     t
    
    
     n
    
    
     G
    
    
     A
    
    
     N
    
   
   
    \mathrm{AttnGAN}
   
  
 AttnGAN中注意力机制的操作原理如下图所示。

 给定输入图像特征向量

     h
    
    
     =
    
    
     (
    
    
     
      h
     
     
      1
     
    
    
     ,
    
    
     
      h
     
     
      2
     
    
    
     ,
    
    
     
      h
     
     
      3
     
    
    
     ,
    
    
     
      h
     
     
      4
     
    
    
     )
    
    
     ∈
    
    
     
      R
     
     
      
       
        D
       
       
        ^
       
      
      
       ×
      
      
       4
      
     
    
   
   
    h=(h^1,h^2,h^3,h^4)\in\mathbb{R}^{\hat{D}\times 4}
   
  
 h=(h1,h2,h3,h4)∈RD^×4和词特征向量
 
  
   
    
     e
    
    
     =
    
    
     (
    
    
     
      e
     
     
      1
     
    
    
     ,
    
    
     
      e
     
     
      2
     
    
    
     ,
    
    
     
      e
     
     
      3
     
    
    
     ,
    
    
     
      e
     
     
      4
     
    
    
     )
    
   
   
    e=(e^1,e^2,e^3,e^4)
   
  
 e=(e1,e2,e3,e4),其中
 
  
   
    
     
      h
     
     
      i
     
    
    
     ∈
    
    
     
      R
     
     
      
       
        D
       
       
        ^
       
      
      
       ×
      
      
       1
      
     
    
   
   
    h^i\in \mathbb{R}^{\hat{D}\times 1}
   
  
 hi∈RD^×1,
 
  
   
    
     
      e
     
     
      i
     
    
    
     ∈
    
    
     
      R
     
     
      
       D
      
      
       ×
      
      
       1
      
     
    
   
   
    e^i\in \mathbb{R}^{D\times 1}
   
  
 ei∈RD×1,
 
  
   
    
     i
    
    
     ∈
    
    
     {
    
    
     1
    
    
     ,
    
    
     2
    
    
     ,
    
    
     3
    
    
     ,
    
    
     4
    
    
     }
    
   
   
    i\in \{1,2,3,4\}
   
  
 i∈{1,2,3,4}。首先利用矩阵
 
  
   
    
     W
    
   
   
    W
   
  
 W进行线性变换将词特征空间
 
  
   
    
     
      R
     
     
      D
     
    
   
   
    \mathbb{R}^{D}
   
  
 RD的向量转换成图像特征空间
 
  
   
    
     
      R
     
     
      
       D
      
      
       ^
      
     
    
   
   
    \mathbb{R}^{\hat{D}}
   
  
 RD^的向量,则有
  
   
    
     
      
       e
      
      
       ^
      
     
     
      =
     
     
      W
     
     
      ⋅
     
     
      e
     
     
      =
     
     
      (
     
     
      
       
        e
       
       
        ^
       
      
      
       1
      
     
     
      ,
     
     
      
       
        e
       
       
        ^
       
      
      
       2
      
     
     
      ,
     
     
      
       
        e
       
       
        ^
       
      
      
       3
      
     
     
      ,
     
     
      
       
        e
       
       
        ^
       
      
      
       4
      
     
     
      )
     
     
      ∈
     
     
      
       R
      
      
       
        
         D
        
        
         ^
        
       
       
        ×
       
       
        4
       
      
     
    
    
     \hat{e}=W\cdot e=(\hat{e}^1,\hat{e}^2,\hat{e}^3,\hat{e}^4)\in \mathbb{R}^{\hat{D}\times 4}
    
   
  e^=W⋅e=(e^1,e^2,e^3,e^4)∈RD^×4然后再利用转换后的词特征
 
  
   
    
     
      e
     
     
      ^
     
    
   
   
    \hat{e}
   
  
 e^与图像特征
 
  
   
    
     h
    
   
   
    h
   
  
 h进行注意力分数的计算得到注意力分数矩阵
 
  
   
    
     S
    
   
   
    S
   
  
 S,其中的分量
 
  
   
    
     
      s
     
     
      
       i
      
      
       j
      
     
    
   
   
    s_{ij}
   
  
 sij​的计算公式为
  
   
    
     
      
       s
      
      
       
        i
       
       
        j
       
      
     
     
      =
     
     
      (
     
     
      
       h
      
      
       i
      
     
     
      
       )
      
      
       ⊤
      
     
     
      ⋅
     
     
      
       
        e
       
       
        ^
       
      
      
       j
      
     
     
      ,
     
     
     
      i
     
     
      ∈
     
     
      {
     
     
      1
     
     
      ,
     
     
      2
     
     
      ,
     
     
      3
     
     
      ,
     
     
      4
     
     
      }
     
     
      ,
     
     
      j
     
     
      ∈
     
     
      {
     
     
      1
     
     
      ,
     
     
      2
     
     
      ,
     
     
      3
     
     
      ,
     
     
      4
     
     
      }
     
    
    
     s_{ij}=(h^i)^{\top}\cdot \hat{e}^j,\quad i\in \{1,2,3,4\},j\in\{1,2,3,4\}
    
   
  sij​=(hi)⊤⋅e^j,i∈{1,2,3,4},j∈{1,2,3,4} 再对矩阵
 
  
   
    
     S
    
   
   
    S
   
  
 S利用
 
  
   
    
     s
    
    
     o
    
    
     f
    
    
     t
    
    
     m
    
    
     a
    
    
     x
    
   
   
    \mathrm{softmax}
   
  
 softmax函数进行注意力分布的计算得到注意力分布矩阵
 
  
   
    
     β
    
    
     ∈
    
    
     
      R
     
     
      
       4
      
      
       ×
      
      
       4
      
     
    
   
   
    \beta\in \mathbb{R}^{4\times 4}
   
  
 β∈R4×4,其中矩阵
 
  
   
    
     β
    
   
   
    \beta
   
  
 β的元素
 
  
   
    
     
      β
     
     
      
       i
      
      
       j
      
     
    
   
   
    \beta_{ij}
   
  
 βij​的计算公式为
  
   
    
     
      
       β
      
      
       
        i
       
       
        j
       
      
     
     
      =
     
     
      
       
        exp
       
       
        ⁡
       
       
        (
       
       
        
         s
        
        
         
          i
         
         
          j
         
        
       
       
        )
       
      
      
       
        
         ∑
        
        
         
          k
         
         
          =
         
         
          1
         
        
        
         3
        
       
       
        exp
       
       
        ⁡
       
       
        (
       
       
        
         s
        
        
         
          i
         
         
          k
         
        
       
       
        )
       
      
     
     
      ,
     
     
     
      i
     
     
      ∈
     
     
      {
     
     
      1
     
     
      ,
     
     
      2
     
     
      ,
     
     
      3
     
     
      ,
     
     
      4
     
     
      }
     
     
      ,
     
     
      l
     
     
      ∈
     
     
      {
     
     
      1
     
     
      ,
     
     
      2
     
     
      ,
     
     
      3
     
     
      ,
     
     
      4
     
     
      }
     
    
    
     \beta_{ij}=\frac{\exp(s_{ij})}{\sum\limits_{k=1}^{3}\exp(s_{ik})},\quad i \in \{1,2,3,4\},l\in\{1,2,3,4\}
    
   
  βij​=k=1∑3​exp(sik​)exp(sij​)​,i∈{1,2,3,4},l∈{1,2,3,4}最后利用注意力分布矩阵
 
  
   
    
     β
    
   
   
    \beta
   
  
 β和图像特征
 
  
   
    
     h
    
   
   
    h
   
  
 h得到最后的输出
 
  
   
    
     o
    
    
     =
    
    
     (
    
    
     
      o
     
     
      1
     
    
    
     ,
    
    
     
      o
     
     
      2
     
    
    
     ,
    
    
     
      o
     
     
      3
     
    
    
     ,
    
    
     
      o
     
     
      4
     
    
    
     )
    
    
     ∈
    
    
     
      R
     
     
      
       
        D
       
       
        ^
       
      
      
       ×
      
      
       4
      
     
    
   
   
    o=(o^1,o^2,o^3,o^4)\in \mathbb{R}^{\hat{D}\times 4}
   
  
 o=(o1,o2,o3,o4)∈RD^×4,即
  
   
    
     
      {
     
     
      
       
        
         
          
           o
          
          
           1
          
         
        
       
       
        
         
          
          
           =
          
          
           
            β
           
           
            11
           
          
          
           ⋅
          
          
           
            h
           
           
            1
           
          
          
           +
          
          
           
            β
           
           
            12
           
          
          
           ⋅
          
          
           
            h
           
           
            2
           
          
          
           +
          
          
           
            β
           
           
            13
           
          
          
           ⋅
          
          
           
            h
           
           
            3
           
          
          
           +
          
          
           
            β
           
           
            14
           
          
          
           ⋅
          
          
           
            h
           
           
            4
           
          
         
        
       
      
      
       
        
         
          
           o
          
          
           2
          
         
        
       
       
        
         
          
          
           =
          
          
           
            β
           
           
            21
           
          
          
           ⋅
          
          
           
            h
           
           
            1
           
          
          
           +
          
          
           
            β
           
           
            22
           
          
          
           ⋅
          
          
           
            h
           
           
            2
           
          
          
           +
          
          
           
            β
           
           
            23
           
          
          
           ⋅
          
          
           
            h
           
           
            3
           
          
          
           +
          
          
           
            β
           
           
            24
           
          
          
           ⋅
          
          
           
            h
           
           
            4
           
          
         
        
       
      
      
       
        
         
          
           o
          
          
           3
          
         
        
       
       
        
         
          
          
           =
          
          
           
            β
           
           
            31
           
          
          
           ⋅
          
          
           
            h
           
           
            1
           
          
          
           +
          
          
           
            β
           
           
            32
           
          
          
           ⋅
          
          
           
            h
           
           
            2
           
          
          
           +
          
          
           
            β
           
           
            33
           
          
          
           ⋅
          
          
           
            h
           
           
            3
           
          
          
           +
          
          
           
            β
           
           
            34
           
          
          
           ⋅
          
          
           
            h
           
           
            4
           
          
         
        
       
      
      
       
        
         
          
           o
          
          
           4
          
         
        
       
       
        
         
          
          
           =
          
          
           
            β
           
           
            41
           
          
          
           ⋅
          
          
           
            h
           
           
            1
           
          
          
           +
          
          
           
            β
           
           
            42
           
          
          
           ⋅
          
          
           
            h
           
           
            2
           
          
          
           +
          
          
           
            β
           
           
            43
           
          
          
           ⋅
          
          
           
            h
           
           
            3
           
          
          
           +
          
          
           
            β
           
           
            44
           
          
          
           ⋅
          
          
           
            h
           
           
            4
           
          
         
        
       
      
     
    
    
     \left\{\begin{aligned}o^1&=\beta_{11}\cdot h^1+\beta_{12}\cdot h^2+\beta_{13}\cdot h^3+\beta_{14}\cdot h^4\\o^2&=\beta_{21}\cdot h^1+\beta_{22}\cdot h^2+\beta_{23}\cdot h^3+\beta_{24}\cdot h^4\\o^3&=\beta_{31}\cdot h^1+\beta_{32}\cdot h^2+\beta_{33}\cdot h^3+\beta_{34}\cdot h^4\\o^4&=\beta_{41}\cdot h^1+\beta_{42}\cdot h^2+\beta_{43}\cdot h^3+\beta_{44}\cdot h^4\end{aligned}\right.
    
   
  ⎩⎪⎪⎪⎪⎨⎪⎪⎪⎪⎧​o1o2o3o4​=β11​⋅h1+β12​⋅h2+β13​⋅h3+β14​⋅h4=β21​⋅h1+β22​⋅h2+β23​⋅h3+β24​⋅h4=β31​⋅h1+β32​⋅h2+β33​⋅h3+β34​⋅h4=β41​⋅h1+β42​⋅h2+β43​⋅h3+β44​⋅h4​

3 Vision Transformer

本节主要详细介绍

     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer的工作原理,3.1节是关于
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer的整体框架,3.2节是关于
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
    
      
    
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer\text{ }Encoder}
   
  
 Transformer Encoder的内部操作细节。对于
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
    
      
    
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer\text{ }Encoder}
   
  
 Transformer Encoder中
 
  
   
    
     M
    
    
     u
    
    
     l
    
    
     t
    
    
     i
    
   
   
    \mathrm{Multi}
   
  
 Multi-
 
  
   
    
     H
    
    
     e
    
    
     a
    
    
     d
    
    
      
    
    
     A
    
    
     t
    
    
     t
    
    
     e
    
    
     n
    
    
     t
    
    
     i
    
    
     o
    
    
     n
    
   
   
    \mathrm{Head\text{ }Attention}
   
  
 Head Attention的原理本文不会赘述,具体想了解的可以参考上一篇文章《Transformer详解(附代码)》中相关原理的介绍。不难发现,不管是自然语言处理中的
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer}
   
  
 Transformer,还是计算机视觉中图像生成的
 
  
   
    
     S
    
    
     A
    
    
     G
    
    
     A
    
    
     N
    
   
   
    \mathrm{SAGAN}
   
  
 SAGAN,以及文本生成图像的
 
  
   
    
     A
    
    
     t
    
    
     t
    
    
     n
    
    
     G
    
    
     A
    
    
     N
    
   
   
    \mathrm{AttnGAN}
   
  
 AttnGAN,它们核心模块中注意力机制的主要目的就是求出注意力分布。

3.1 Vision Transformer整体框架

如果下图所示为

     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer的整体框架以及相应的训练流程
  • 给定一张图片 X ∈ R 3 n × 3 n X\in \mathbb{R}^{3n\times 3n} X∈R3n×3n,并将它分割成 9 9 9个 p a t c h \mathrm{patch} patch分别为 x 1 , ⋯   , x 9 ∈ R n × n x^1,\cdots,x^9\in\mathbb{R}^{n\times n} x1,⋯,x9∈Rn×n。然后再将这个 9 9 9个 p a t c h \mathrm{patch} patch拉平,则有 x 1 , ⋯   , x 9 ∈ R n 2 x^1,\cdots,x^9\in\mathbb{R}^{n^2} x1,⋯,x9∈Rn2
  • 利用矩阵 W ∈ R l × n 2 W\in \mathbb{R}^{l \times n^2} W∈Rl×n2将拉平后的向量 x i ∈ R n 2 , i ∈ { 1 , ⋯   , 9 } x^i\in\mathbb{R}^{n^2},i\in{1,\cdots,9} xi∈Rn2,i∈{1,⋯,9}经过线性变换得到图像编码向量 z i ∈ R l , i ∈ { 1 , ⋯   , 9 } z^i\in \mathbb{R}^{l},i\in{1,\cdots,9} zi∈Rl,i∈{1,⋯,9},具体的计算公式为 z i = W ⋅ x i , i ∈ { 1 , ⋯ 9 } z^i = W\cdot x^i,\quad i\in{1,\cdots9} zi=W⋅xi,i∈{1,⋯9}
  • 然后将图像编码向量 z i , i ∈ { 1 , ⋅ , 9 } z^{i},i\in{1,\cdot,9} zi,i∈{1,⋅,9}和类编码向量 z 0 z^0 z0分别与对应的位置编进行加和得到输入编码向量,则有 z i + p i ∈ R l , i ∈ { 0 , ⋯ 9 } z^{i}+p^{i}\in\mathbb{R}^l,\quad i\in{0,\cdots 9} zi+pi∈Rl,i∈{0,⋯9}
  • 接着将输入编码向量输入到 V i s i o n T r a n s f o r m e r E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder中得到对应的输出 o i ∈ R l , i ∈ { 0 , ⋯   , 9 } o^i\in \mathbb{R}^l,i\in{0,\cdots,9} oi∈Rl,i∈{0,⋯,9}
  • 最后将类编码向量 o 0 o^0 o0输入全连接神经网络中 M L P \mathrm{MLP} MLP得到类别预测向量 y ^ ∈ R c \hat{y}\in\mathbb{R}^c y^​∈Rc,并与真实类别向量 y ∈ R c y\in\mathbb{R}^c y∈Rc计算交叉熵损失得到损失值 l o s s loss loss,利用优化算法更新模型的权重参数

注意事项: 看到这里可能会有一个疑问为什么预测类别的时候只用到了类别编码向量

      o
     
     
      0
     
    
   
   
    o^0
   
  
 o0,
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
    
      
    
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer\text{ }Encoder}
   
  
 Vision Transformer Encoder其它的输出为什么没有输入到
 
  
   
    
     M
    
    
     L
    
    
     P
    
   
   
    \mathrm{MLP}
   
  
 MLP中?为了回答这个问题,我们令函数
 
  
   
    
     
      f
     
     
      0
     
    
    
     (
    
    
     ⋅
    
    
     )
    
   
   
    f_0(\cdot)
   
  
 f0​(⋅)为
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
    
      
    
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer\text{ }Encoder}
   
  
 Vision Transformer Encoder,则类编码向量
 
  
   
    
     
      o
     
     
      0
     
    
   
   
    o^{0}
   
  
 o0可以表示为
  
   
    
     
      
       o
      
      
       0
      
     
     
      =
     
     
      
       f
      
      
       0
      
     
     
      (
     
     
      
       z
      
      
       0
      
     
     
      +
     
     
      
       p
      
      
       0
      
     
     
      ,
     
     
      ⋯
      
     
      ,
     
     
      
       z
      
      
       9
      
     
     
      +
     
     
      
       p
      
      
       9
      
     
     
      )
     
    
    
     o^0=f_0(z^0+p^0,\cdots,z^9+p^9)
    
   
  o0=f0​(z0+p0,⋯,z9+p9)由上公式可以发现,类编码向量
 
  
   
    
     
      o
     
     
      0
     
    
   
   
    o^{0}
   
  
 o0是属于高层特征,其实它综合了所有的图像编码信息,所以可以用它来进行分类,这个可以类比在卷积神经网络中最后的类别输出向量其实就是一层层卷积得到的高层特征。

3.2 Transformer Encoder操作原理

如下图所示分别为

     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
    
      
    
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer\text{ }Encoder}
   
  
 Vision Transformer Encoder模型结构图和原始
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
    
      
    
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer\text{ }Encoder}
   
  
 Transformer Encoder的模型结构图。可以直观的发现
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
    
      
    
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer\text{ }Encoder}
   
  
 Vision Transformer Encoder和
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
    
      
    
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer\text{ }Encoder}
   
  
 Transformer Encoder都有层归一化,多头注意力机制,残差连接和线性变换这四个操作,只是在操作顺序有所不同。在以下的
 
  
   
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{ \text{ }Transformer}
   
  
  Transformer代码实例中,将以下两种
 
  
   
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Encoder}
   
  
 Encoder网络结构都进行了实现,可以发现两种网络结构都可以进行很好的训练。

下图左半部分

      V
     
     
      i
     
     
      s
     
     
      i
     
     
      o
     
     
      n
     
     
       
     
     
      T
     
     
      r
     
     
      a
     
     
      n
     
     
      s
     
     
      f
     
     
      o
     
     
      r
     
     
      m
     
     
      e
     
     
      r
     
     
       
     
     
      E
     
     
      n
     
     
      c
     
     
      o
     
     
      d
     
     
      e
     
     
      r
     
    
    
     \mathrm{Vision\text{ }Transformer\text{ }Encoder}
    
   
  Vision Transformer Encoder具体的操作流程为
  • 给定输入编码矩阵 Z ∈ R l × n Z\in\mathbb{R}^{l\times n} Z∈Rl×n,首先将其进行层归一化得到 Z ′ ∈ R l × n Z^{\prime}\in\mathbb{R}^{l \times n} Z′∈Rl×n
  • 利用矩阵 W q , W k , W v ∈ R l × l W^{q},W^{k},W^{v}\in \mathbb{R}^{l\times l} Wq,Wk,Wv∈Rl×l对 Z ′ Z^{\prime} Z′进行线性变换得到矩阵 Q , K , W ∈ R l × n Q,K,W\in\mathbb{R}^{l\times n} Q,K,W∈Rl×n具体的计算过程为 { Q = W q ⋅ Z ′ K = W k ⋅ Z ′ V = W v ⋅ Z ′ \left{\begin{aligned}Q &= W^{q}\cdot Z^{\prime}\K&=W^{k}\cdot Z^{\prime}\V&=W^v \cdot Z^{\prime}\end{aligned}\right. ⎩⎪⎨⎪⎧​QKV​=Wq⋅Z′=Wk⋅Z′=Wv⋅Z′​再将这三个矩阵输入到 M u l t i \mathrm{Multi} Multi- H e a d A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention(该原理参考《Transformer详解(附代码)》)中得到矩阵 Z ′ ′ ∈ R l × n Z^{\prime\prime}\in \mathbb{R}^{l \times n} Z′′∈Rl×n将最原始的输入矩阵 Z Z Z与 Z ′ ′ Z^{\prime\prime} Z′′进行残差计算得到 Z + Z ′ ′ ∈ R l × n Z+Z^{\prime\prime}\in \mathbb{R}^{l\times n} Z+Z′′∈Rl×n
  • 将 Z + Z ′ ′ Z+Z^{\prime\prime} Z+Z′′进行第二次层归一化得到 Z ′ ′ ′ ∈ R l × n Z^{\prime\prime\prime}\in\mathbb{R}^{l\times n} Z′′′∈Rl×n,然后再将 Z ′ ′ ′ Z^{\prime\prime\prime} Z′′′输入到全连接神经网络中进行线性变换得到 Z ′ ′ ′ ′ ∈ R l × n Z^{\prime\prime\prime\prime}\in\mathbb{R}^{l\times n} Z′′′′∈Rl×n。最后将 Z + Z ′ ′ Z+Z^{\prime\prime} Z+Z′′与 Z ′ ′ ′ ′ Z^{\prime\prime\prime\prime} Z′′′′进行残差操作得到该 B l o c k \mathrm{Block} Block的输出 Z + Z ′ ′ + Z ′ ′ ′ ′ ∈ R l × n Z+Z^{\prime\prime}+Z^{\prime\prime\prime\prime}\in\mathbb{R}^{l\times n} Z+Z′′+Z′′′′∈Rl×n。一个 E n c o d e r \mathrm{Encoder} Encoder可以将 N N N个 B l o c k \mathrm{Block} Block进行堆叠,最后得到的输出为 O ∈ R l × n O\in\mathbb{R}^{l\times n} O∈Rl×n。

4 程序代码

     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer的代码示例如下所示。该代码是由上一篇《Transformer详解(附代码)》的代码的基础上改编而来。
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer的作者的本意就是想让在
 
  
   
    
     N
    
    
     L
    
    
     P
    
   
   
    \mathrm{NLP}
   
  
 NLP中的
 
  
   
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Transformer}
   
  
 Transformer模型架构做尽可能少的修改可以直接迁移到
 
  
   
    
     C
    
    
     V
    
   
   
    \mathrm{CV}
   
  
 CV中,所以以下程序尽可能保持作者的原意,并在代码实现了两种
 
  
   
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Encoder}
   
  
 Encoder的网络结构,即3.2节图片所示的两个网络结构,一种是最原始的
 
  
   
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Encoder}
   
  
 Encoder网络结构,一种是
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer论文里的
 
  
   
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Encoder}
   
  
 Encoder的网络结构。这里需要注意的是,
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer里并能没有
 
  
   
    
     D
    
    
     e
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Decoder}
   
  
 Decoder模块,所以不需要计算
 
  
   
    
     E
    
    
     n
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Encoder}
   
  
 Encoder和
 
  
   
    
     D
    
    
     e
    
    
     c
    
    
     o
    
    
     d
    
    
     e
    
    
     r
    
   
   
    \mathrm{Decoder}
   
  
 Decoder的交叉注意力分布,这就进一步给
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer的编程带来了简便。
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision\text{ }Transformer}
   
  
 Vision Transformer的开源代码的网址为https://github.com/lucidrains/vit-pytorch/tree/main/vit_pytorch。
import torch
import torch.nn as nn
import os
from einops import rearrange
from einops import repeat
from einops.layers.torch import Rearrange

definputs_deal(inputs):return inputs ifisinstance(inputs,tuple)else(inputs, inputs)classSelfAttention(nn.Module):def__init__(self, embed_size, heads):super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert(self.head_dim * heads == embed_size),"Embed size needs to be div by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)defforward(self, values, keys, query):
        N =query.shape[0]
        value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]# split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)# queries shape: (N, query_len, heads, heads_dim)# keys shape : (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)

        attention = torch.softmax(energy/(self.embed_size **(1/2)), dim=3)

        out = torch.einsum("nhql, nlhd->nqhd",[attention, values]).reshape(N, query_len, self.heads*self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, heads_dim)# (N, query_len, heads, head_dim)

        out = self.fc_out(out)return out

classTransformerBlock(nn.Module):def__init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size))
        self.dropout = nn.Dropout(dropout)defforward(self, value, key, query, x, type_mode):if type_mode =='original':
            attention = self.attention(value, key, query)
            x = self.dropout(self.norm(attention + x))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm(forward + x))return out
        else:
            attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
            x =self.dropout(attention + x)
            forward = self.feed_forward(self.norm(x))
            out = self.dropout(forward + x)return out

classTransformerEncoder(nn.Module):def__init__(
            self,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout =0,
            type_mode ='original'):super(TransformerEncoder, self).__init__()
        self.embed_size = embed_size
        self.type_mode = type_mode
        self.Query_Key_Value = nn.Linear(embed_size, embed_size *3, bias =False)

        self.layers = nn.ModuleList([
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,)for _ inrange(num_layers)])
        self.dropout = nn.Dropout(dropout)defforward(self, x):for layer in self.layers:
            QKV_list = self.Query_Key_Value(x).chunk(3, dim =-1)
            x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)return x

classVisionTransformer(nn.Module):def__init__(self, 
                image_size, 
                patch_size, 
                num_classes, 
                embed_size, 
                num_layers, 
                heads, 
                mlp_dim, 
                pool ='cls',
                channels =3,
                dropout =0,
                emb_dropout =0.1,
                type_mode ='vit'):super(VisionTransformer, self).__init__()
        img_h, img_w = inputs_deal(image_size)
        patch_h, patch_w = inputs_deal(patch_size)assert img_h % patch_h ==0and img_w % patch_w ==0,'Img dimensions can be divisible by the patch dimensions'

        num_patches =(img_h // patch_h)*(img_w // patch_w)

        patch_size = channels * patch_h * patch_w

        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2=patch_w),
            nn.Linear(patch_size, embed_size, bias=False))

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches +1, embed_size))
        self.cls_token = nn.Parameter(torch.randn(1,1, embed_size))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = TransformerEncoder(embed_size, 
                                    num_layers, 
                                    heads, 
                                    mlp_dim,
                                    dropout)
        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes))defforward(self, img):
        x = self.patch_embedding(img)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token,'() n d ->b n d', b = b)
        x = torch.cat((cls_tokens, x), dim =1)
        x += self.pos_embedding[:,:(n +1)]
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim =1)if self.pool =='mean'else x[:,0]
        x = self.to_latent(x)return self.mlp_head(x)if __name__ =='__main__':
    vit = VisionTransformer(
            image_size =256,
            patch_size =16,
            num_classes =10,
            embed_size =256,
            num_layers =6,
            heads =8,
            mlp_dim =512,
            dropout =0.1,
            emb_dropout =0.1)
    img = torch.randn(3,3,256,256)
    pred = vit(img)print(pred)

以下代码是利用

     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision \text{ }Transformer}
   
  
 Vision Transformer网络结构训练一个分类
 
  
   
    
     m
    
    
     n
    
    
     i
    
    
     s
    
    
     t
    
   
   
    \mathrm{mnist}
   
  
 mnist数据集的主程序代码。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import VIT
import os

        
deftrain():
    batch_size =4
    device = torch.device('cuda'if torch.cuda.is_available()else'cpu')
    epoches =20
    mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= batch_size, shuffle=True)
    mnist_model = VIT.VisionTransformer(
        image_size =28,
        patch_size =7,
        num_classes =10,
        channels =1,
        embed_size =512,
        num_layers =1,
        heads =2,
        mlp_dim =1024,
        dropout =0,
        emb_dropout =0)
    loss_fn = nn.CrossEntropyLoss()
    mnist_model = mnist_model.to(device)
    opitimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)
    mnist_model.train()for epoch inrange(epoches):
        total_loss =0 
        corrects =0 
        num =0for batch_X, batch_Y in train_loader:
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            opitimizer.zero_grad()
            outputs = mnist_model(batch_X)
            _, pred = torch.max(outputs.data,1)
            loss = loss_fn(outputs, batch_Y)
            loss.backward()
            opitimizer.step()
            total_loss += loss.item()
            corrects = torch.sum(pred == batch_Y.data)
            num += batch_size
            print(epoch, total_loss/float(num), corrects.item()/float(batch_size))if __name__ =='__main__':
    train()

训练的过程如下所示,可以发现损失函数可以稳定下降。但是训练一个

     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision \text{ }Transformer}
   
  
 Vision Transformer模型真的是很烧硬件,跟训练一个普通的
 
  
   
    
     C
    
    
     N
    
    
     N
    
   
   
    \mathrm{CNN}
   
  
 CNN模型相比,训练一个
 
  
   
    
     V
    
    
     i
    
    
     s
    
    
     i
    
    
     o
    
    
     n
    
    
      
    
    
     T
    
    
     r
    
    
     a
    
    
     n
    
    
     s
    
    
     f
    
    
     o
    
    
     r
    
    
     m
    
    
     e
    
    
     r
    
   
   
    \mathrm{Vision \text{ }Transformer}
   
  
 Vision Transformer模型更加耗时耗力。


本文转载自: https://blog.csdn.net/qq_38406029/article/details/122157116
版权归原作者 鬼道2022 所有, 如有侵权,请联系我们删除。

“Vision Transformer详解(附代码)”的评论:

还没有评论