0


Swin Transformer原理详解篇

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊往期回顾:CV攻城狮入门VIT(vision transformer)之旅——近年超火的Transformer你再不了解就晚了! CV攻城狮入门VIT(vision transformer)之旅——VIT原理详解篇 CV攻城狮入门VIT(vision transformer)之旅——VIT代码实战篇

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

文章目录

Swin Transformer原理详解篇

写在前面

​  在前面我们已经很系统的介绍了Transformer的相关知识,从NLP中的Transformer开始讲起,然后以此为基础详细的介绍了CV邻域VIT模型的原理和代码,对此部分不了解的可先点击下述链接了解详情:

  1. CV攻城狮入门VIT(vision transformer)之旅——近年超火的Transformer你再不了解就晚了! 🍁🍁🍁
  2. CV攻城狮入门VIT(vision transformer)之旅——VIT原理详解篇 🍁🍁🍁
  3. CV攻城狮入门VIT(vision transformer)之旅——VIT代码实战篇 🍁🍁🍁

​  那么这篇文章将为大家来讲讲Swin Transformer模型,在叙述这个模型之前大家可以去这个网站https://paperswithcode.com 看看这几年各项任务的sota,几乎都采用了swin Transformer模型,这足以彰显此模型的强大之处!!!【注意这个网站的paper后有个s喔,否则你可能会进入什么奇怪的网站😥这奇怪网站的域名起的可太无语了…🤐】

​  准备好了嘛,我们这就发车。🚖🚖🚖

网络整体框架

​  在谈及网络的框架之前,我们先来看看这篇Swin Transformer和我们之前所介绍的VIT模型有什么区别,论文中给出了下图:

image-20220819162819591

​  从直观上来看,左图(a)貌似比较复杂,相应的右图(b)就显得简单多了。再看的仔细点,可以发现(a)可以用于classification(分类)、segmentation(分割)、detection(检测)等一系列下游任务,而(b)就只用于classification,也就是说此篇Swin Transformer是视觉领域中一个通用的方案,可以应用于多种任务。🌼🌼🌼

​  我们再来看看上图中的一些细节,很容易发现,在(b)中图片一直都是采用16倍的下采样,【要是你了解VIT的原理这里肯定就明白,在图像输入的时候会用一个卷积对原图进行16倍的下采样,后面尺寸一直没变】 而在(a)中先是进行4倍下采样,然后是8倍,接着是16倍,一直这样下去,直到达到你想要的下采样倍数。这里特征图尺寸不断变小,是不是感觉和卷积非常像呢,其实这里和卷积的思想是一样的,采用这种层次化的构建方法有利于实现检测和分割任务。【这里是怎么进行下采样的我们后文回详细讲解,大家现在其实就完全可以理解为卷积】 🌼🌼🌼

​  上图还体现了Swin Transformer很重要的一点,那就是窗口(windows)的概念。可以看到,(a)图中的特征图都被划分成了一个个的小窗口,Swin Transformer会将这些小窗口送入到Transformer模型中,这样做的好处是可以大幅减小计算量。【这部分是Swin Transformer的难点,也是重点,后文我也会详细的为大家介绍】🌼🌼🌼


​  介绍了Swin Transformer和ViT的一些区别,下面我们就来看看Swin Transformer的整体框架,如下图所示:

image-20220819171441397

​  图(a)为Swin Transformer的结构, 这里我们不解释每个结构的细节,而是看看经过这些结构维度的变化。首先假设输入图片尺寸为H×W×3,首先经过patch Partion将图片分成一个个patch,patch大小为4×4,则经过此步后一共会有

     H
    
    
     4
    
   
   
    ×
   
   
    
     W
    
    
     4
    
   
  
  
   \frac{H}{4}×\frac{W}{4}
  
 
4H​×4W​个patch,因图片有三个通道,故每个patch的尺寸为

 
  
   
    4
   
   
    ×
   
   
    4
   
   
    ×
   
   
    3
   
   
    =
   
   
    48
   
  
  
   4×4×3=48
  
 
4×4×3=48 ,即我们通过Patch Partion结构得到特征图尺寸为

 
  
   
    
     H
    
    
     4
    
   
   
    ×
   
   
    
     W
    
    
     4
    
   
   
    ×
   
   
    48
   
  
  
   \frac{H}{4}×\frac{W}{4}×48
  
 
4H​×4W​×48。接着我们会通过Linera Embedding层,这就是一个全连接层,会将刚刚

 
  
   
    
     H
    
    
     4
    
   
   
    ×
   
   
    
     W
    
    
     4
    
   
   
    ×
   
   
    48
   
  
  
   \frac{H}{4}×\frac{W}{4}×48
  
 
4H​×4W​×48的特征图映射为

 
  
   
    
     H
    
    
     4
    
   
   
    ×
   
   
    
     W
    
    
     4
    
   
   
    ×
   
   
    C
   
  
  
   \frac{H}{4}×\frac{W}{4}×C
  
 
4H​×4W​×C。**【如果你对ViT模型熟悉的话就会发现,这里基本是一样的。在ViT代码中这步操作是直接通过卷积实现的,Swin Transformer这部分代码同样是由一个卷积实现】**

​  现在我们得到的是

     H
    
    
     4
    
   
   
    ×
   
   
    
     W
    
    
     4
    
   
   
    ×
   
   
    C
   
  
  
   \frac{H}{4}×\frac{W}{4}×C
  
 
4H​×4W​×C的特征图,下面会通过Swin Transformer Block结构,这里我们可以先将其理解为ViT中的Transformer Encoder结构,经过这个结构后输出尺寸仍然为

 
  
   
    
     H
    
    
     4
    
   
   
    ×
   
   
    
     W
    
    
     4
    
   
   
    ×
   
   
    C
   
  
  
   \frac{H}{4}×\frac{W}{4}×C
  
 
4H​×4W​×C。我们注意到这个结构下面写了一个×2,表示我们会重复这个结构两次。图(b)中表示的就是重复两次的Swin Transformer Block结构,这两次结构是有一些区别的且总是成对出现,因此后面的Swin Transformer Block重复的次数总是2的倍数。**【注意上文所输入Swin Transformer Bolok的尺寸为
  
   
    
     
      
       H
      
      
       4
      
     
     
      ×
     
     
      
       W
      
      
       4
      
     
     
      ×
     
     
      C
     
    
    
     \frac{H}{4}×\frac{W}{4}×C
    
   
  4H​×4W​×C,但其实输入Transformer中的尺寸应是二维的向量,所以代码中我们会将前两个维度放一起,即维度变为
  
   
    
     
      (
     
     
      
       H
      
      
       4
      
     
     
      ∗
     
     
      
       W
      
      
       4
      
     
     
      )
     
     
      ×
     
     
      C
     
    
    
     (\frac{H}{4}*\frac{W}{4})×C
    
   
  (4H​∗4W​)×C,之后将其送入Transformer中。同样输出也是二维的向量,我们得到后再将其展开即好。这些都是在阅读代码后才能知道的细节,大家稍加注意一点就好】**

​  通过第一个Swin Transformer Block后,尺寸为

     H
    
    
     4
    
   
   
    ×
   
   
    
     W
    
    
     4
    
   
   
    ×
   
   
    C
   
  
  
   \frac{H}{4}×\frac{W}{4}×C
  
 
4H​×4W​×C,会将其送入Patch Merging。这层实现了将特征图分辨率减半,通道数翻倍的操作,是不是很像CNN中的卷积呢,其实这里实现的功能是和卷积一模一样的,但是实现的方式有所不同,后文会详细解释。通过这层后,输出的特征图尺寸变为了

 
  
   
    
     H
    
    
     8
    
   
   
    ×
   
   
    
     W
    
    
     8
    
   
   
    ×
   
   
    2
   
   
    C
   
  
  
   \frac{H}{8}×\frac{W}{8}×2C
  
 
8H​×8W​×2C。

​  接下来都是一些重复的结构了,相信大家通过我上文的描述也能理解了。那么下面就要为大家详细的介绍某些结构的细节了,接着往下看吧!!!🌻🌻🌻


网络结构细节🧨🧨🧨

Patch partition+Linear Embedding

​  为了保证网络结构的完整性,关于这个结构我再简单的提一下。其实在上文我也大概叙述了Patch partition+Linear Embedding是干什么的了,而且也说了在代码中是怎么实现的——通过一个卷积操作即可完成这两步。这里我不想再唠叨了,如果你还想知道更详细的信息,在我的VIT原理详解篇有关于此部分的详细描述,感兴趣的自己去看看吧!!!🥂🥂🥂

Patch Merging

​  按网络结构流程来讲,这部分应该是介绍Swin Transformer Block结构,但这是整篇文章最核心、最难理解也是内容最多的,所以我打算放在最后一节叙述,这样由易到难进行介绍大家可能会更好的接收。

​  其实啊,我上文已经说了经过Patch Merging会达到怎样的效果——特征图分辨率减半,通道数翻倍。那么Patch Merging具体做了什么呢?我们可以来看一下下图:

​  假设我们的输入是4×4大小单通道的特征图,首先我们会隔一个取一个小Patch组合在一起,最后4×4的特征图会行成4个2×2的特征图。【这部分看上图是很好理解的,但是在代码部分对Pathon语法不熟悉的可能觉得有点难理解,我会再下一篇代码解读中为大家详细讲讲代码是怎么实现的】接下来将4个Patch进行拼接,现在得到的特征图尺寸为2×2×4。然后会经过一个LN层,这里当然会改变特征图的值,我改变了一些颜色象征性的表示了一下,LN层后特征图尺寸不会改变,仍为2×2×4。最后会经过一个全连接层,将特征图尺寸由2×2×4变为2×2×2。到这里,就把Patch Merging的原理给介绍完了,大家可以看一下输入和输出的结果是不是实现了特征图分辨率减半,通道数翻倍呢?🏵🏵🏵

Swin Transformer Block✨✨✨

​这部分的结构如下图所示:

image-20220822104440562

​我们再来看一下ViT中的encoder结构,如下图所示:

image-20220822105029811

​  对照这两个图,你可能会发现结构基本是一样的。主要区别只在一个地方,即ViT Encoder绿框中的Multi-Head Attention和Swin Transformer Block紫框中的W-MSA(Windows Multi-Head Attention) \SW-MSA(Shifted Windows Multi-Head Attention)。Swin Transformer Block中的两个结构的区别也只在这里有所不同,下面我就重点来为大家讲讲W-MSA和SW-MSA。



W-MSA

​  下面先来看看W-MSA。何为W-MSA,即Windows Multi-Head Attention,它也是一个多头的自注意机制。它和传统的 Multi-Head Attention的区别就在于W-MSA会先将特征图分成一个个Windows,然后对每个Windows执行Multi-Head Attention操作,如下图所示:

​  这时候肯定就有人问了,采用W-MSA相比于传统的MSA有什么优势呢?——这样可以大大的减少模型计算量,下图为论文中给出的两者计算量的对比:

image-20220822111514616

​  其中,h和w分别表示特征图的高度和宽度,C代表特征图的通道数,M表示窗口的大小。

​  那么这个公式是怎么得来的呢,下面就来详细说说MSA和W-MSA的计算量。


MSA计算量

​  根据我们前面几篇文章的介绍,我们应该对单头自注意的公式很熟悉了,即

    A
   
   
    t
   
   
    t
   
   
    e
   
   
    n
   
   
    t
   
   
    i
   
   
    o
   
   
    n
   
   
    (
   
   
    Q
   
   
    ,
   
   
    K
   
   
    ,
   
   
    V
   
   
    )
   
   
    =
   
   
    S
   
   
    o
   
   
    f
   
   
    t
   
   
    M
   
   
    a
   
   
    x
   
   
    (
   
   
    
     
      Q
     
     
      
       K
      
      
       T
      
     
    
    
     
      
       d
      
      
       k
      
     
    
   
   
    )
   
   
    V
   
  
  
   Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d_k}})V
  
 
Attention(Q,K,V)=SoftMax(dk​​QKT​)V。我们先来根据这个公式一步步的推导MSA的计算量。

​  我们假设,特征图尺寸为h×w×C,第一步我们会通过乘对应的

     W
    
    
     q
    
   
   
    、
   
   
    
     W
    
    
     k
    
   
   
    、
   
   
    
     W
    
    
     v
    
   
  
  
   W_q、W_k、W_v
  
 
Wq​、Wk​、Wv​矩阵生成

 
  
   
    Q
   
   
    、
   
   
    K
   
   
    、
   
   
    V
   
  
  
   Q、K、V
  
 
Q、K、V,这里假设

 
  
   
    Q
   
   
    、
   
   
    K
   
   
    、
   
   
    V
   
  
  
   Q、K、V
  
 
Q、K、V的向量长度与特征图A是深度c是一致的。我们来来看看由A乘

 
  
   
    
     W
    
    
     q
    
   
  
  
   W_q
  
 
Wq​得到

 
  
   
    Q
   
  
  
   Q
  
 
Q所需的计算量,如下图所示:

​  A的维度为

    h
   
   
    w
   
   
    ×
   
   
    C
   
  
  
   hw×C
  
 
hw×C,

 
  
   
    
     W
    
    
     q
    
   
  
  
   W_q
  
 
Wq​维度为

 
  
   
    C
   
   
    ×
   
   
    C
   
  
  
   C×C
  
 
C×C,则它们相乘后结果

 
  
   
    Q
   
  
  
   Q
  
 
Q维度为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    C
   
  
  
   hw×C
  
 
hw×C。再来看

 
  
   
    Q
   
  
  
   Q
  
 
Q中的每个红点(每个像素)都由

 
  
   
    A
   
  
  
   A
  
 
A中的一行C个绿点和

 
  
   
    
     W
    
    
     q
    
   
  
  
   W_q
  
 
Wq​中一列C个黄点相乘得到,即得到一个红点会进行相乘操作

 
  
   
    C
   
  
  
   C
  
 
C次,也即得到一个红点的计算量为

 
  
   
    C
   
  
  
   C
  
 
C。

 
  
   
    Q
   
  
  
   Q
  
 
Q中共有

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    C
   
  
  
   hw×C
  
 
hw×C个红点,即计算量为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    C
   
   
    ×
   
   
    C
   
   
    =
   
   
    h
   
   
    w
   
   
    
     C
    
    
     2
    
   
  
  
   hw×C×C=hwC^2
  
 
hw×C×C=hwC2。同理,生成K、V所需的计算量都为

 
  
   
    h
   
   
    w
   
   
    
     C
    
    
     2
    
   
  
  
   hwC^2
  
 
hwC2,因此生成

 
  
   
    Q
   
   
    、
   
   
    K
   
   
    、
   
   
    V
   
  
  
   Q、K、V
  
 
Q、K、V的过程共需计算量为

 
  
   
    3
   
   
    h
   
   
    w
   
   
    
     C
    
    
     2
    
   
  
  
   3hwC^2
  
 
3hwC2。**【注:后文我不会再画图帮大家理解计算量是怎么得到的了,大家可按照我上文思路画图理解,或者我们可以根据上文得到矩阵乘法计算量的一般公式,即
  
   
    
     
      (
     
     
      a
     
     
      ×
     
     
      b
     
     
      )
     
     
       
     
     
      ×
     
     
       
     
     
      (
     
     
      b
     
     
      ×
     
     
      c
     
     
      )
     
    
    
     (a×b) \ × \ (b×c)
    
   
  (a×b) × (b×c)的计算量为
  
   
    
     
      a
     
     
      b
     
     
      c
     
    
    
     abc
    
   
  abc】**

​  接着来看

    Q
   
   
    
     K
    
    
     T
    
   
  
  
   QK^T
  
 
QKT所用的计算量。

 
  
   
    Q
   
  
  
   Q
  
 
Q的维度为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    C
   
  
  
   hw×C
  
 
hw×C,

 
  
   
    
     K
    
    
     T
    
   
  
  
   K^T
  
 
KT的维度为为

 
  
   
    C
   
   
    ×
   
   
    h
   
   
    w
   
  
  
   C×hw
  
 
C×hw,相乘后

 
  
   
    Q
   
   
    
     K
    
    
     T
    
   
  
  
   QK^T
  
 
QKT维度为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    h
   
   
    w
   
  
  
   hw×hw
  
 
hw×hw,则此步所用计算量为

 
  
   
    (
   
   
    h
   
   
    w
   
   
    
     )
    
    
     2
    
   
   
    C
   
  
  
   (hw)^2 C
  
 
(hw)2C。接下来会除

 
  
   
    
     
      (
     
     
      
       d
      
      
       k
      
     
     
      )
     
    
   
  
  
   \sqrt{(d_k)}
  
 
(dk​)​和进行SoftMax操作,这两步计算量较少,可忽略。

​  然后用

    Q
   
   
    
     K
    
    
     T
    
   
  
  
   QK^T
  
 
QKT和

 
  
   
    V
   
  
  
   V
  
 
V相乘。

 
  
   
    Q
   
   
    
     K
    
    
     T
    
   
  
  
   QK^T
  
 
QKT维度为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    h
   
   
    w
   
  
  
   hw×hw
  
 
hw×hw,

 
  
   
    V
   
  
  
   V
  
 
V的维度为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    C
   
  
  
   hw×C
  
 
hw×C,得到结果维度为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    C
   
  
  
   hw×C
  
 
hw×C,则此步所用计算量为

 
  
   
    (
   
   
    h
   
   
    w
   
   
    
     )
    
    
     2
    
   
   
    C
   
  
  
   (hw)^2 C
  
 
(hw)2C。

​  那么到这里,单头自注意力机制所用计算量就介绍完了,为上述几个计算量的和,即

    3
   
   
    h
   
   
    w
   
   
    
     C
    
    
     2
    
   
   
    +
   
   
    (
   
   
    h
   
   
    w
   
   
    
     )
    
    
     2
    
   
   
    C
   
   
    +
   
   
    (
   
   
    h
   
   
    w
   
   
    
     )
    
    
     2
    
   
   
    C
   
   
    =
   
   
    3
   
   
    h
   
   
    w
   
   
    
     C
    
    
     2
    
   
   
    +
   
   
    2
   
   
    (
   
   
    h
   
   
    w
   
   
    
     )
    
    
     2
    
   
   
    C
   
  
  
   3hwC^2+(hw)^2C+(hw)^2C=3hwC^2+2(hw)^2C
  
 
3hwC2+(hw)2C+(hw)2C=3hwC2+2(hw)2C。

​  对应多头注意力(MSA)来说,其计算量和单头注意差别就在最后一步乘

     W
    
    
     o
    
   
  
  
   W^o
  
 
Wo矩阵过程。上一步得到的矩阵维度为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    C
   
  
  
   hw×C
  
 
hw×C,

 
  
   
    
     W
    
    
     o
    
   
  
  
   W^o
  
 
Wo维度为

 
  
   
    C
   
   
    ×
   
   
    C
   
  
  
   C×C
  
 
C×C,则最后结果维度为

 
  
   
    h
   
   
    w
   
   
    ×
   
   
    C
   
  
  
   hw×C
  
 
hw×C。**【和输入时维度一致】**这一步所需的计算量为

 
  
   
    h
   
   
    w
   
   
    
     C
    
    
     2
    
   
  
  
   hwC^2
  
 
hwC2。

​  综上,MSA所用计算量为**

      3
     
     
      h
     
     
      w
     
     
      
       C
      
      
       2
      
     
     
      +
     
     
      2
     
     
      (
     
     
      h
     
     
      w
     
     
      
       )
      
      
       2
      
     
     
      C
     
     
      +
     
     
      h
     
     
      w
     
     
      
       C
      
      
       2
      
     
     
      =
     
     
      4
     
     
      h
     
     
      w
     
     
      
       C
      
      
       2
      
     
     
      +
     
     
      2
     
     
      (
     
     
      h
     
     
      w
     
     
      
       )
      
      
       2
      
     
     
      C
     
    
    
     3hwC^2+2(hw)^2C+hwC^2=4hwC^2+2(hw)^2C
    
   
  3hwC2+2(hw)2C+hwC2=4hwC2+2(hw)2C** 。

W-MSA计算量

​  W-MSA就是把特征图分成几个小窗口分别送入MSA。所以我们只需要计算每个小窗口的计算量,然后乘上窗口数量即可得到W-MAS的计算量。现假设每个窗口的宽高都为M,则一共有

     h
    
    
     M
    
   
   
    ×
   
   
    
     w
    
    
     M
    
   
  
  
   \frac{h}{M}×\frac{w}{M}
  
 
Mh​×Mw​个窗口。接下来就先要计算一个窗口的计算量,对于一个窗口来说,其实就是一个MSA,计算公式上文已经给出,我们只需要将上文的

 
  
   
    h
   
   
    和
   
   
    w
   
  
  
   h和w
  
 
h和w换成现在的

 
  
   
    M
   
   
    和
   
   
    M
   
  
  
   M 和 M
  
 
M和M即可求得一个窗口计算量,其为

 
  
   
    4
   
   
    
     M
    
    
     2
    
   
   
    
     C
    
    
     2
    
   
   
    +
   
   
    2
   
   
    (
   
   
    
     M
    
    
     2
    
   
   
    
     )
    
    
     2
    
   
   
    C
   
  
  
   4M^2C^2+2(M^2)^2C
  
 
4M2C2+2(M2)2C。而现在共有 

 
  
   
    
     h
    
    
     M
    
   
   
    ×
   
   
    
     w
    
    
     M
    
   
  
  
   \frac{h}{M}×\frac{w}{M}
  
 
Mh​×Mw​个窗口,则W-MSA的计算量为**
  
   
    
     
      
       h
      
      
       M
      
     
     
      ×
     
     
      
       w
      
      
       M
      
     
     
      ×
     
     
      (
     
     
      4
     
     
      
       M
      
      
       2
      
     
     
      
       C
      
      
       2
      
     
     
      +
     
     
      2
     
     
      (
     
     
      
       M
      
      
       2
      
     
     
      
       )
      
      
       2
      
     
     
      C
     
     
      )
     
     
      =
     
     
      4
     
     
      h
     
     
      w
     
     
      
       C
      
      
       2
      
     
     
      +
     
     
      2
     
     
      
       M
      
      
       2
      
     
     
      h
     
     
      w
     
     
      C
     
    
    
     \frac{h}{M}×\frac{w}{M}×(4M^2C^2+2(M^2)^2C)=4hwC^2+2M^2hwC
    
   
  Mh​×Mw​×(4M2C2+2(M2)2C)=4hwC2+2M2hwC**。

​  你仔细对比着两个公式,你会发现区别在第二项。MSA为

    2
   
   
    (
   
   
    h
   
   
    w
   
   
    
     )
    
    
     2
    
   
   
    C
   
  
  
   2(hw)^2C
  
 
2(hw)2C,W-MSA为

 
  
   
    2
   
   
    
     M
    
    
     2
    
   
   
    h
   
   
    w
   
   
    C
   
  
  
   2M^2hwC
  
 
2M2hwC。其中

 
  
   
    h
   
   
    w
   
  
  
   hw
  
 
hw为原图的高和宽,而

 
  
   
    M
   
  
  
   M
  
 
M为窗口的高和宽。显然

 
  
   
    M
   
   
    ≤
   
   
    h
   
   
    ,
   
   
    M
   
   
    ≤
   
   
    w
   
  
  
   M \le h,M \le w
  
 
M≤h,M≤w,所以W-MSA的计算量要少。我认为这样说大家还感觉不到计算量差别有多大,我举个例子大家来感受一下:假设特征图h=224,w=224,C=3,M=7,则使用W-MSA比使用MSA节省15091034112计算量,如下:

     2
    
    
     (
    
    
     h
    
    
     w
    
    
     
      )
     
     
      2
     
    
    
     C
    
    
     −
    
    
     2
    
    
     
      M
     
     
      2
     
    
    
     h
    
    
     w
    
    
     C
    
    
     =
    
    
     2
    
    
     ×
    
    
     22
    
    
     
      4
     
     
      4
     
    
    
     ×
    
    
     3
    
    
     −
    
    
     2
    
    
     ×
    
    
     7
    
    
     ×
    
    
     7
    
    
     ×
    
    
     22
    
    
     
      4
     
     
      2
     
    
    
     ×
    
    
     3
    
    
     =
    
    
     15091034112
    
   
   
    2(hw)^2C-2M^2hwC=2×224^4×3-2×7×7×224^2×3=15091034112
   
  
 2(hw)2C−2M2hwC=2×2244×3−2×7×7×2242×3=15091034112


SW-MSA

​  这部分我认为是整篇文章最核心的东西,这里我会详细的为大家解释解释。何为SW-MSA,即shifted Windows Multi-Head Attention。我们先来想想我们为何要使用SW-MSA,这是由于W-MSA将原始特征图分成一个个小窗口,然后分别送入MSA中,这会导致各个窗口之前没有任何的联系,都是独立的,这显然不是我们期望看到的。而SW-MSA的出现就是为了解决这个问题,SW-MSA具体是怎么设计的,我们来看论文中的图片解释:

image-20220822172547017

​  左图是W-MSA,右图是SW-MSA,它们分别出现在Swin Transformer Block的相邻两层中。若W-MSA在Layer1层使用,则SW-MSA在Layer1+1层使用。

​  再来看看SW-MSA做了什么?SW-MSA会重新划分窗口,即由上图左侧变成上图右侧,这样划分过后会形成9个大小形状不同的窗口,这样就解决了窗口直接无法进行信息传递的问题。我们以上图右侧第二行第二列中间4×4的窗口为例,这一个窗口结合了Layer1层中四个窗口的信息,是不是很巧妙呢!!!**【至于怎么得到SW-MSA划分的窗口的呢,很多博客中基本都是说将原始特征图从左上角分别向右侧和下侧移动

       ⌊
      
      
       
        M
       
       
        2
       
      
      
       ⌋
      
     
     
      =
     
     
      2
     
    
    
     \left\lfloor {\frac{{\rm{M}}}{2}} \right\rfloor = 2
    
   
  ⌊2M​⌋=2个像素,M为窗口大小。但是这样描述我还是不能很好的理解,这个划分窗格后面还有其它操作,代码中这些操作是在一起进行的,就一行代码,过程也很容易理解,所以这部分我会在代码详解篇再详细讲讲】**

​  划分好窗口后,就可以对每个窗口进行MSA了。但是你会发现每个窗口的尺寸都不一样,这样做MSA是很复杂的。一个很容易想到的方法就是将每个窗口都padding成原始窗口大小,但很明显这样做又增大了计算量。作者提出了一种非常巧妙的方法,即通过shift(移动)将划分后的窗口进行重组,为方便大家理解,作图如下:

​  首先,给9个窗口标上数字,这样容易展示移动后的结果。第①步将0 1 2这三个窗口移到最下面得到图①,然后将3 6 0这三个窗口移动到最右侧得到图②。我们来观察一下最后得到的图②,4单独可以构成一个窗口,5和3可以合并构成一个窗口,7和1可以合并构成一个窗口,8 6 2 0 也可以合并构成一个窗口。此时一共可以分成4个窗口,和没划分窗口前一样,但是现在的4个窗口就解决了原先窗口无法进行信息传递的问题,例如现在5 3构成的窗口融合了原始四个窗口的信息,即使原始的四个窗口之前有了联系。

​  但是这样做还存在一个严重的问题,就是会导致信息全乱套了。怎么说呢,我们可以以5 3 区域为例进行解释。5 3区域在原图上是不相邻的两个区域,先经过一系列操作后会将它们放在一起进行MSA操作,这显然是不合适的。因此为了防止这种问题,在使用过程中,我们会使用Msked掩码来隔绝不同区域的信息,具体是怎么操作的呢,如下图所示:【这里偷个懒,不想画图了,图片来自于B站霹雳吧啦Wz ,大家可以去看看他的视频,真的会收获很多】

​  这里的窗口大小是4×4的,进行MSA时,对于窗口中每个元素都会先生成

    Q
   
   
    、
   
   
    K
   
   
    、
   
   
    V
   
  
  
   Q、K、V
  
 
Q、K、V,然后计算每个像素之前的相关性,即计算attention socres。我们拿像素0为例,它得到

 
  
   
    
     q
    
    
     0
    
   
  
  
   q^0
  
 
q0后会和所有像素的

 
  
   
    
     k
    
    
     T
    
   
  
  
   k^T
  
 
kT进行相乘,得到16个attention scores。上图

 
  
   
    
     a
    
    
     
      0
     
     
      ,
     
     
      0
     
    
   
  
  
   a_{0,0}
  
 
a0,0​表示

 
  
   
    
     q
    
    
     0
    
   
  
  
   q^0
  
 
q0与

 
  
   
    
     
      k
     
     
      0
     
    
    
     T
    
   
  
  
   {k^0}^T
  
 
k0T相乘的结果,

 
  
   
    
     a
    
    
     
      0
     
     
      ,
     
     
      1
     
    
   
  
  
   a_{0,1}
  
 
a0,1​表示

 
  
   
    
     q
    
    
     0
    
   
  
  
   q^0
  
 
q0与

 
  
   
    
     
      k
     
     
      1
     
    
    
     T
    
   
  
  
   {k^1}^T
  
 
k1T相乘的结果,依此类推,共得到16个结果。此时我们不会直接进行SoftMax操作,而是先将像素0与区域3中所有像素匹配结果都减去100,如上图中的

 
  
   
    
     a
    
    
     
      0
     
     
      ,
     
     
      2
     
    
   
   
    、
   
   
    
     a
    
    
     
      0
     
     
      ,
     
     
      3
     
    
   
   
    、
   
   
    
     a
    
    
     
      0
     
     
      ,
     
     
      6
     
    
   
   
    、
   
   
    
     a
    
    
     
      0
     
     
      ,
     
     
      7
     
    
   
  
  
   a_{0,2}、a_{0,3}、a_{0,6}、a_{0,7}
  
 
a0,2​、a0,3​、a0,6​、a0,7​等等。这样操作后再进行SoftMax操作会将减去100的那些attension scores都变成0。**【本来attention socres的值较小,减去100后经过SoftMax基本为0】**经过这样的操作,像素0其实没有和区域3中的像素进行任何操作,即只相当于在区域5中进行了MSA。那么接下来对于像素1、2、3等都是同样的道理。这里我额外的提一下,上图为4×4的窗口,我们一共会生成多少个attention scores呢?其实很简单啦,一个像素会有16个,一共16个像素,所有一共会有16×16个attention scores,记住这个小点喔,后面会考滴🌲🌲🌲【**注意:在我们操作完后,还需要将移动后的窗口还原回去,关于这一点,我在代码详解篇也会阐述】**

​  最后,我也给出论文中关于此部分的图解,现在看看,是不是很好理解了呢🍜🍜🍜

image-20220822215543745

Relative Position Bias详解✨✨✨

​  上文已经介绍了Swin Transformer非常关键的几个结构,相信你看懂这几个部分,也就基本能看懂整个Swin Transformer的结构了。这一节来为大家介绍相对位置偏置(Relative Position Bias),这是什么呢,其实啊,这个相对位置偏置所起的作用就和Tranformer中的绝对位置编码和ViT的带参数的位置编码是一样的,不过作者通过实验证明在Swin Transformer中使用相对位置偏置的效果要更好,如下图所示:

image-20220822215937732

​  通过上图消融实验可以发现,实验相对位置偏置的效果最好。那么问题来了,这个相对位置偏置应该放在什么位置呢,不卖关子了,论文中是放在Self Attention的计算公式中,如下:

     A
    
    
     t
    
    
     t
    
    
     e
    
    
     n
    
    
     t
    
    
     i
    
    
     o
    
    
     n
    
    
     (
    
    
     Q
    
    
     ,
    
    
     K
    
    
     ,
    
    
     V
    
    
     )
    
    
     =
    
    
     S
    
    
     o
    
    
     f
    
    
     t
    
    
     M
    
    
     a
    
    
     x
    
    
     (
    
    
     
      
       Q
      
      
       
        K
       
       
        T
       
      
     
     
      
       d
      
     
    
    
     +
    
    
     B
    
    
     )
    
    
     V
    
   
   
    Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d}}+B)V
   
  
 Attention(Q,K,V)=SoftMax(d​QKT​+B)V

​  上述公式中的B即为相对位置偏置。我们先来讨论讨论B的维度应该是怎样的。既然B可以和

      Q
     
     
      
       K
      
      
       T
      
     
    
    
     
      d
     
    
   
  
  
   \frac{QK^T}{\sqrt{d}}
  
 
d​QKT​相加,那么B的维度一定是和

 
  
   
    Q
   
   
    
     K
    
    
     T
    
   
  
  
   QK^T
  
 
QKT维度一致的。

 
  
   
    Q
   
   
    
     K
    
    
     T
    
   
  
  
   QK^T
  
 
QKT代表的是attention scores,它表示的是像素之前的相关性,他的维度和窗口大小有关,对于一个4×4的窗口,其一共有16个像素,则

 
  
   
    Q
   
   
    
     K
    
    
     T
    
   
  
  
   QK^T
  
 
QKT的维度为16×16。是不是绝对很熟悉呢,这点我在上文介绍SW-MSA提到过喔。那我再来举一个例子,如下图所示:

image-20220822221813231
​  这个窗口的大小为7×7大小,那么将这个窗口送入MSA中,B的维度是多少呢?没错啦,就是49×49!!!🥂🥂🥂

​  这个B的维度搞清楚了后,下面介绍起来就容易多了。下文图片也是采用霹雳吧啦Wz 的,大家可以去关注看看喔!!!

​  下面以2×2的窗口为大家展示相对位置偏置是怎么得到的,首先明确一点,窗口大小为2×2,则B的维度为4×4。

image-20220822222804547

​  这里来解释一下上图。首先对于一个2×2的图,我们可以很容易的得到二维的绝对位置索引,如第一行第一列用(0,0)表示,第一行第二列用(0,1)表示。接着会以每个像素为基准,行列都减去其它位置的索引得到相对位置索引。例如以蓝色像素(0,0)为基准,先用(0,0)-(0,0)=(0,0),得到第一个相对位置索引,接着用(0,0)-(0,1)=(0,-1)得到第一行第二列的位置索引,同理可以得到其它相对位置索引。这个过程其实就模拟了q和k的匹配过程。我们用这种方法,会得到4个2×2的相对位置索引,然后我们将其按行展平并拼接在一起就得到了相对位置索引矩阵,它是4×4的。

​  上图最后的结构相对位置索引矩阵是二维的,我们想将其变成一维矩阵,然后根据相对位置索引去相对位置偏置表里去查找对应的参数,这里的相对位置偏置表是一个可学习的参数。一个很自然的想法就是把上图中的行和列相加,不就可以变成一维的了嘛。确实是这样没错啦,但是这样粗暴的相加会产生一定的问题,如第一行第二列元素(0,-1)行列相加后是-1,第一行第三列元素(-1,0)行列相加后也是-1,但是(0,-1)和(-1,0)明显代表了不同的相对位置,所以这么做是不合适的。

​  下面我们来看看作者是如何优雅的进行操作的吧!!!如下图所示:
image-20220822225318512
​  此时,我们会将上图得到的4×4的矩阵对照相对位置偏置表得到最终的B,其维度为4×4。如下图所示:
image-20220822225542393
​  注意:这个相对位置偏置表维度为(2M-1)×(2M-1),这个大家可以自己捣鼓捣鼓看看能不能推出了,不行的话评论区见吧!!!再强调一下,这个相对位置偏置表是一个可学习的参数。等等还有一点值得一提,对于固定尺寸的窗口我们的相对位置索引矩阵(relative position index)是相同的!!!🌼🌼🌼

模型参数

​  这里我就直接放图了,大家看的会更直观:

image-20220822230258420

​  主要有四个模型,参数各有差异。其中win.sz.表示窗口大小,可以看到每个模型的窗口大小都为7×7;dim表示通道数或向量长度;head表示MSA中head个数。下面我以Swin-T为例,画了其经过每个模块后的维度变化,大家可以参考,如下图所示:

swin_transformer

小结

​  随便写着写着也快7000字了,写这篇文章时中间隔了两天去湖北溜达了一圈,所以感觉写的不是很累,我感觉一些重要的知识点也都涉及到了,如果发现还有补充的点我会更新上去,总之希望大家都能够学明白Swin Transformer的原理吧!!!🍻🍻🍻

​  在下一篇我将为大家带来这部分的代码实战,在Swin Transformer模型构建中有的部分还是有点难理解的,下一篇我尽量的表述清楚吧。最后说一句老生常谈的话,代码你还是得自己多调试,才能真正的理解。🌾🌾🌾

参考链接

Swin Transformer论文精读【论文精读】 🍁🍁🍁

Swin-Transformer网络结构详解🍁🍁🍁

Swin Transformer从零详细解读🍁🍁🍁

Vision Transformer 超详细解读 (原理分析+代码解读) (十七)🍁🍁🍁

如若文章对你有所帮助,那就🛴🛴🛴

一键三连 (1).gif


本文转载自: https://blog.csdn.net/qq_47233366/article/details/128624093
版权归原作者 秃头小苏 所有, 如有侵权,请联系我们删除。

“Swin Transformer原理详解篇”的评论:

还没有评论