0


Swin-transformer详解

前言

这篇论文提出了一个新的 Vision Transformer 叫做 Swin Transformer,它可以被用来作为一个计算机视觉领域一个通用的骨干网络.但是直接把Transformer从 NLP 用到 Vision 是有一些挑战的,这个挑战主要来自于两个方面

一个就是尺度上的问题。因为比如说现在有一张街景的图片,里面有很多车和行人,里面的物体都大大小小,那这时候代表同样一个语义的词,比如说行人或者汽车就有非常不同的尺寸,这种现象在 NLP 中就没有

另外一个挑战是图像的 resolution太大了,如果要以像素点作为基本单位的话,序列的长度就变得高不可攀,所以说之前的工作要么就是用后续的特征图来当做Transformer的输入,要么就是把图片打成 patch 减少这个图片的 resolution,要么就是把图片画成一个一个的小窗口,然后在窗口里面去做自注意力,所有的这些方法都是为了减少序列长度.

基于这两个挑战,本文的作者就提出了 hierarchical Transformer,它的特征是通过一种叫做移动窗口的方式来学习的

移动窗口的好处:不仅带来了更大的效率,因为跟之前的工作一样,现在自注意力是在窗口内算的,所以这个序列的长度大大的降低了;同时通过 shifting 移动的这个操作,能够让相邻的两个窗口之间有了交互,所以上下层之间就可以有 cross-window connection,从而变相的达到了一种全局建模的能力。然后作者说这种层级式的结构不仅非常灵活,可以提供各个尺度的特征信息,同时因为自注意力是在小窗口之内算的,所以说它的计算复杂度是随着图像大小而线性增长,而不是平方级增长,这其实也为作者之后提出 Swin V2 铺平了道路,从而让他们可以在特别大的分辨率上去预训练模型

论文:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
官方开源代码地址:https://github.com/microsoft/Swin-Transformer

网络架构

首先基于上图对Swin Transformer和之前的Vision Transformer做了下对比。

  • Vision Transformer就是把图片打成 patch,Vit中patch size 16x16的所以图中写着16x,也就意味着是16倍的下采样率,这也就意味着每一个 patch,也就是每一个 token,自始至终代表的尺寸都是一样的。每一层的Transformer block 看到token的尺寸都是16倍下采样率。虽然它可以通过这种全局的自注意力操作,达到全局的建模能力,但是它对多尺寸特征的把握就会弱一些。swin Transformer使用了类似卷积神经网络中的层次化构建方法(Hierarchical feature maps) 刚开始的下采样率是4倍,然后变成了8倍、16倍,之所以刚开始是4×的,是因为最开始的 patch 是4乘4大小的,一旦有了多尺寸的特征信息,有了这种4x、8x、16x的特征图。
  • 所以当Swin Transformer有了多尺寸的特征信息 输给一个 FPN,从而就可以去做检测了 扔给一个 UNET,然后就可以去做分割了 所以这就是作者在这篇论文里反复强调的,Swin Transformer是能够当做一个通用的骨干网络的,不光是能做图像分类,还能做密集预测性的任务

对于视觉任务,尤其是下游任务比如说检测和分割来说,多尺寸的特征是至关重要的,比如说对目标检测而言,运用最广的一个方法就是 FPN(a feature pyramid network:当有一个分层式的卷积神经网络之后,每一个卷积层出来的特征的 receptive field (感受野)是不一样的,能抓住物体不同尺寸的特征,从而能够很好的处理物体不同尺寸的问题。

接下来,简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图,主要的关键点有:

patch merging

,

W-MSA

,

SW_MSA

以及

transform block

的计算过程.首先我们通过下图梳理下前向过程

前向过程

  1. 假设说有一张224*224*3(ImageNet 标准尺寸)的输入图片 第一步就是像 ViT 那样把图片打成 patch,在 Swin Transformer 这篇论文里,它的 patch size 是4*4,而不是像 ViT 一样16*16,所以说它经过 patch partition 打成 patch 之后,得到图片的尺寸是56*56*4856就是224/4,因为 patch size 是4,向量的维度48,因为4*4*33 是图片的 RGB 通道
  2. 打完了 patch ,接下来就要做 Linear Embedding,也就是说要把向量的维度变成一个预先设置好的值,就是 Transformer 能够接受的值,在 Swin Transformer 的论文里把这个超参数设为 c,对于 Swin tiny 网络来说,也就是上图中画的网络总览图,它的 c 是96,所以经历完 Linear Embedding 之后,输入的尺寸就变成了56*56*96,前面的56*56就会拉直变成3136,变成了序列长度,后面的96就变成了每一个token向量的维度,其实 Patch Partition 和 Linear Embedding 就相当于是 ViT 里的Patch Projection 操作,而在代码里也是用一次卷积操作就完成了,
  3. 第一部分跟 ViT 其实还是没有区别的,但紧接着区别就来了
  4. 首先序列长度是3136,对于 ViT 来说,用 patch size 16*16,它的序列长度就只有196,是相对短很多的,这里的3136就太长了,是目前来说Transformer不能接受的序列长度,所以 Swin Transformer就引入了基于窗口的自注意力计算,每个窗口按照默认来说(M=7),都只有七七四十九个 patch,所以说序列长度就只有49就相当小了,这样就解决了计算复杂度的问题
  5. 所以也就是说, stage1中的swin transformer block 是基于窗口计算自注意力的,现在暂时先把 transformer block当成是一个黑盒,只关注输入和输出的维度,对于 Transformer 来说,如果不对它做更多约束的话,Transformer输入的序列长度是多少,输出的序列长度也是多少,它的输入输出的尺寸是不变的,所以说在 stage1 中经过两层Swin Transformer block 之后,输出还是56*56*96
  6. 到这其实 Swin Transformer的第一个阶段就走完了,也就是先过一个 Patch Projection 层,然后再过一些 Swin Transformer block,接下来如果想要有多尺寸的特征信息,就要构建一个层级式的 transformer,也就是说需要一个像卷积神经网络里一样,有一个类似于池化的操作

Patch Merging

这篇论文里作者就提出 Patch Merging 的操作,Patch Merging 其实在之前一些工作里也有用到,它很像 Pixel Shuffle 的上采样的一个反过程,Pixel Shuffle 是 lower level 任务中很常用的一个上采样方式

  • 假如有一个张量, Patch Merging 顾名思义就是把临近的小 patch 合并成一个大 patch,这样就可以起到下采样一个特征图的效果了
  • 这里因为是想下采样两倍,所以说在选点的时候是每隔一个点选一个,也就意味着说对于这个张量来说,如下图每次选的点都是同一种颜色
  • 如果原张量的维度是 h * w * c ,当然这里 c 没有画出来,经过这次采样之后就得到了4个张量,每个张量的大小是 h/2、w/2,它的尺寸都缩小了一倍
  • 现在把这四个张量在 c 的维度上拼接起来,张量的大小就变成了 h/2 * w/2 * 4c,相当于用空间上的维度换了更多的通道数
  • 假如有一个张量, Patch Merging 顾名思义就是把临近的小 patch 合并成一个大 patch,这样就可以起到下采样一个特征图的效果了

在这里插入图片描述

  • 通过这个操作,就把原来一个大的张量变小了,就像卷积神经网络里的池化操作一样,为了跟卷积神经网络那边保持一致(不论是 VGGNet 还是 ResNet,一般在池化操作降维之后,通道数都会翻倍,从128变成256,从256再变成512),所以这里也只想让他翻倍,而不是变成4倍,所以紧接着又再做了一次操作,就是在 c 的维度上用一个1乘1的卷积,把通道数降下来变成2c,通过这个操作就能把原来一个大小为 h*w*c 的张量变成 h/2 * w/2 *2c 的一个张量,也就是说空间大小减半,但是通道数乘2,这样就跟卷积神经网络完全对等起来了

W-MSA详解

窗口是怎样划分的

原图片会被平均的分成一些没有重叠的窗口,拿第一层之前的输入来举例,它的尺寸就是

56*56*96

,也就说有一个维度是

56*56

张量,然后把它切成一些不重叠的方格,也就是下图表示的方格

  • 每一个方格就是一个窗口,如果在VIT中这个窗口就是最小单元了然后做Multi-head Self-Attention.但在Swin—T 中这个窗口并不是最小的计算单元 最小的计算单元是 patch 也就意味着每一个小窗口里其实还有 m * m 个 patch,在 Swin Transformer 这篇论文里一般 m 默认为7。在上图中的窗口中再划分如下图7*7的patch

  • 原来大的整体特征图到底里面会有多少个窗口呢?其实也就是每条边56/78个窗口,也就是说一共会有8*8等于64个窗口,就是说会在这64个窗口里分别去算它们的自注意力(Windows Multi-head Self-Attention)

基于窗口的自注意力模式的计算复杂度

MSA和W-MSA两者的计算量具体差多少呢?原论文中有给出下面两个公式:

       Ω 
      
     
       ( 
      
     
       M 
      
     
       S 
      
     
       A 
      
     
       ) 
      
     
       = 
      
     
     
      
      
        4 
       
      
        h 
       
      
        w 
       
      
        C 
       
      
     
       2 
      
     
    
      + 
     
     
      
      
        2 
       
      
        ( 
       
      
        h 
       
      
        w 
       
      
        ) 
       
      
     
       2 
      
     
    
      C 
                  
    
      ( 
     
    
      1 
     
    
      ) 
     
    
   
     {Ω(MSA)=}{4hwC}^{2}+{2(hw)}^{2}C\, \, \, \, \, \, \, \, \, \, \, \, \, (1) 
    
   
 Ω(MSA)=4hwC2+2(hw)2C(1)

  
   
    
     
     
       Ω 
      
     
       ( 
      
     
       W 
      
     
       − 
      
     
       M 
      
     
       S 
      
     
       A 
      
     
       ) 
      
     
       = 
      
     
     
      
      
        4 
       
      
        h 
       
      
        w 
       
      
        C 
       
      
     
       2 
      
     
    
      + 
     
     
      
      
        2 
       
      
        M 
       
      
     
       2 
      
     
    
      h 
     
    
      w 
     
    
      C 
                  
    
      ( 
     
    
      2 
     
    
      ) 
     
    
   
     {Ω(W-MSA)=}{4hwC}^{2}+{2M}^{2}hwC\, \, \, \, \, \, \, \, \, \, \, \, \, (2) 
    
   
 Ω(W−MSA)=4hwC2+2M2hwC(2)
  • h代表feature map的高度
  • w代表feature map的宽度
  • C代表feature map的深度
  • M代表每个窗口(Windows)的大小

上面的公式就是基于Self-Attention计算得来的,Self-Attention的公式如下:

      A 
     
    
      t 
     
    
      t 
     
    
      e 
     
    
      n 
     
    
      t 
     
    
      i 
     
    
      o 
     
    
      n 
     
    
      ( 
     
    
      Q 
     
    
      , 
     
    
      K 
     
    
      , 
     
    
      V 
     
    
      ) 
     
    
      = 
     
    
      s 
     
    
      o 
     
    
      f 
     
    
      t 
     
    
      M 
     
    
      a 
     
    
      x 
     
    
      ( 
     
     
      
       
       
         Q 
        
       
         K 
        
       
      
        T 
       
      
      
       
       
         d 
        
       
         k 
        
       
      
     
    
      ) 
     
    
      V 
        
    
      ( 
     
    
      3 
     
    
      ) 
     
    
   
     Attention(Q,K,V)=softMax(\frac {{QK}^{T}} {\sqrt {{d}_{k}}})V \, \, \, (3) 
    
   
 Attention(Q,K,V)=softMax(dk​​QKT​)V(3)

如下图解释计算过程如下图

在这里插入图片描述

MSA计算量

对于feature map中的每个像素(或称作token,patch),都要通过

      w 
     
    
      q 
     
    
   
  
    {w}_{q} 
   
  
wq​,  
 
  
   
    
    
      w 
     
    
      k 
     
    
   
  
    {w}_{k} 
   
  
wk​,  
 
  
   
    
    
      w 
     
    
      v 
     
    
   
  
    {w}_{v} 
   
  
wv​,

生成对应的query(q),key(k)以及value(v) 这里假设q, k, v的向量长度与feature map的深度C保持一致。那么对应所有像素生成Q的过程如下式:

       Q 
      
      
      
        h 
       
      
        w 
       
      
        × 
       
      
        c 
       
      
     
    
      = 
     
     
      
      
        A 
       
       
       
         h 
        
       
         w 
        
       
         × 
        
       
         c 
        
       
      
      
     
    
      × 
     
     
     
       W 
      
     
       q 
      
      
      
        c 
       
      
        × 
       
      
        c 
       
      
        
    
      ( 
     
    
      4 
     
    
      ) 
     
    
   
     {Q}^{hw\times c}={{A}^{hw\times c}}_{}×{W}^{c\times c}_{q}\, \, \, (4) 
    
   
 Qhw×c=Ahw×c​×Wqc×c​(4)
  • 首先,得到q,k,v 步骤。相当于是用一个 h*w*c 的向量乘以一个 c*c 的系数矩阵,最后得到了 h*w*c。所以每一个计算的复杂度是 h*w*c^2,因为有三次操作,所以是 3h*w*c^2
  • 然后,算自注意力就是 h*w*c乘以 k 的转置,也就是 c*h*w,所以得到了 h*w*h*w,这个计算复杂度就是(h*w)^2*c
  • 接下来,自注意力矩阵和value的乘积的计算复杂度还是 (h*w)^2*c,所以现在就成了2*(h*w)^2*c
  • 最后一步,投射层也就是h*w*c乘以 c*c 变成了 h*w*c ,它的计算复杂度就又是 h*w*c^2
  • 最后合并起来就是最后的公式(1)
W-MSA计算量
  • 因为在每个窗口里算的还是多头自注意力,所以可以直接套用公式(1),只不过高度和宽度变化了,现在高度和宽度不再是 h * w,而是变成窗口有多大了,也就是 M*M,也就是说现在 h 变成了 M,w 也是 M,它的序列长度只有 M * M 这么大

  • 所以当把 M 值带入到公式(1)之后,就得到计算复杂度是 4 ( M C ) 2 + 2 ( M ) 4 C {4(MC)}^{2}+2{(M)}^{4}C 4(MC)2+2(M)4C,这个就是在一个窗口里算多头自注意力所需要的计算复杂度 那我们现在一共有 h/M * w/M 个窗口,现在用这么多个窗口乘以每个窗口所需要的计算复杂度就能得到公式(2)了

         h 
        
       
         M 
        
       
      
        × 
       
       
       
         w 
        
       
         M 
        
       
      
        × 
       
       
        
        
          ( 
         
        
          4 
         
        
          ( 
         
        
          M 
         
        
          C 
         
        
          ) 
         
        
       
         2 
        
       
      
        + 
       
       
        
        
          2 
         
        
          ( 
         
        
          M 
         
        
          ) 
         
        
       
         4 
        
       
      
        C 
       
      
        ) 
       
      
        = 
       
       
        
        
          4 
         
        
          h 
         
        
          w 
         
        
          C 
         
        
       
         2 
        
       
      
        + 
       
       
        
        
          2 
         
        
          M 
         
        
       
         2 
        
       
      
        h 
       
      
        w 
       
      
        C 
       
      
     
       \frac {h} {M}\times \frac {w} {M}\times {(4(MC)}^{2}+{2(M)}^{4}C)={4hwC}^{2}+{2M}^{2}hwC 
      
     
    

    Mh​×Mw​×(4(MC)2+2(M)4C)=4hwC2+2M2hwC

假设feature map的h、w都为112,M=7,C=128,采用W-MSA模块相比MSA模块能够节省约40124743680 FLOPs:

        4 
       
      
        h 
       
      
        w 
       
      
        C 
       
      
     
       2 
      
     
    
      + 
     
     
      
      
        2 
       
      
        M 
       
      
     
       2 
      
     
    
      h 
     
    
      w 
     
    
      C 
     
    
      = 
     
    
      2 
     
    
      × 
     
     
     
       112 
      
     
       4 
      
     
    
      × 
     
    
      128 
     
    
      − 
     
    
      2 
     
    
      × 
     
     
     
       7 
      
     
       2 
      
     
    
      × 
     
     
     
       112 
      
     
       2 
      
     
    
      × 
     
    
      128 
     
    
      = 
     
    
      40124743680 
     
    
   
     {4hwC}^{2}+{2M}^{2}hwC=2\times {112}^{4}\times 128-2\times {7}^{2}\times {112}^{2}\times 128=40124743680 
    
   
 4hwC2+2M2hwC=2×1124×128−2×72×1122×128=40124743680

SW-MSA

前面有说,采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。如下图所示,左侧使用的是刚刚讲的W-MSA(假设是第L层),那么根据之前介绍的W-MSA和SW-MSA是成对使用的,那么第L+1层使用的就是SW-MSA(右侧图)。根据左右两幅图对比能够发现窗口(Windows)发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了

     ⌊ 
    
    
    
      M 
     
    
      2 
     
    
   
     ⌋ 
    
   
  
    \left \lfloor \frac {M} {2} \right \rfloor 
   
  
⌊2M​⌋个像素)。看下偏移后的窗口(右侧图),比如对于第一行第2列的2x4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。再比如,第二行第二列的4x4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。

根据上图,可以发现通过将窗口进行偏移后,由原来的4个窗口变成9个窗口了。后面又要对每个窗口内部进行MSA,这样做感觉又变麻烦了。为了解决这个麻烦,作者又提出而了Efficient batch computation for shifted configuration,一种更加高效的计算方法。如原论文中的示意图。

看了论文中的图还是不能理解它是如何循环移位的,然后我自己画了个示意图如下

在这里插入图片描述

高效的批量计算

作者通过这种巧妙的循环位移的方式和巧妙设计的掩码模板,从而实现了只需要一次前向过程,就能把所有需要的自注意力值都算出来,而且只需要计算4个窗口,也就是说窗口的数量没有增加,计算复杂度也没有增加,非常高效的完成了这个任务
这也是Swin-transformer最精华的部分之一。不同的窗口对应的不同掩码可视化图如下:

图片来自于原作者

masked计算过程

假设下面这幅图是经过cyclic shift过的位置,变成了9个窗口,窗口中的数字代表窗口的编号,同一窗口中的元素或者是相邻的窗口是可以可以做自注意力的。下面的这个图把9个窗口从中间分成4个窗口来做自注意力。0中的数据是可以相互做self attention
其中1,2不是相邻。3,6也不相邻。4,5,7,8 也不相邻,所以我们不希望他们做self attention。

假设图片的宽高是14x14。

在这里插入图片描述

下面我们以窗口3,6为例子做self attention 计算讲解。

我们把窗口3,6中的的7x7=49个patch,每个patch其实就是一个向量,我们把窗口拉直(窗口内从左到右从上到下取元素)就会变成如下的列向量(图A),其中有28个长度的3号窗口元素,21个6号窗口元素,然后把向量转置(图B)进行相乘得到矩阵(图C)

在这里插入图片描述

其中C矩阵中

33

,

66

他们来自同一个窗口的计算结果。是我们想要的,

36

,

63

不是我们想要的,因为他们是不相邻的窗口计算的结果,所以作者基于此,设计出了如下图的掩码:

在这里插入图片描述

其中-100可以认为是一个很大的负数,因为在做self attention 的时候矩阵里面的值经过归一化都是很小的。然后让下图掩码D和矩阵C做加法,这样

36

63

就会变成很小的负数,然后再做softmax这种很小的数就变成了0,也意味着我把这两块区域内容都masked掉了。同样的,我们也可以用这个计算方法得到窗口1,2的mask。窗口4,5,7,8的mask。最终的结果就是作者给的如下这个mask图.

注意,在计算完后还要把数据给挪回到原来的位置上

详细的体系结构规范

下图是原论文中给出的关于不同Swin Transformer的配置,T(Tiny),S(Small),B(Base),L(Large)
其中假设所有体系结构的输入图像大小为224×224。“Concatn×n”表示一个补丁中n×n个相邻特征的patch。此操作导致特征图的降采样率为n。“96-d”表示输出维数为96的线性层

参考链接

https://www.bilibili.com/read/cv14877004
https://blog.csdn.net/qq_37541097/article/details/121119988


本文转载自: https://blog.csdn.net/BXD1314/article/details/129659124
版权归原作者 @左左@右右 所有, 如有侵权,请联系我们删除。

“Swin-transformer详解”的评论:

还没有评论