0


Swin Transformer代码阅读注释

在这里插入图片描述

Swin Transformer代码阅读注释

前言

上一篇博文以论文中的内容介绍了Swin Transformer的网络结构和一些细节。本篇博文将从官方代码中的

  1. swin_transformer.py

去详细介绍Swin Transformer结构,并补充代码中才有而论文中没有的细节。

|| 如果对 Swin Transformer 不了解建议先看论文介绍再看源码 ||
Swin Transformer介绍博客:论文阅读笔记:Swin Transformer

Swin Transformer

代码中实现的网络结构与论文中的结构如如下:
在这里插入图片描述
如上图所示,代码中使用

  1. PatchEmbed

来实现

  1. Patch Partition

+

  1. Linear Embedding

,使用

  1. BasicLayer

来实现

  1. Swin Transformer Block

+

  1. PatchMerging

,对于最后一个

  1. BasicLayer

不使用

  1. PatchMerging

来降采样。

Swin-T 的配置如下:
Swin-T参数配置表
在这里插入图片描述
网络结构介绍可看Swin Transformer介绍博客:论文阅读笔记:Swin Transformer
整体结构代码和注释如下(代码大部分和和 Vision Transformer 是一样):

  1. classSwinTransformer(nn.Module):def__init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
  2. embed_dim=96, depths=[2,2,6,2], num_heads=[3,6,12,24],
  3. window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
  4. drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
  5. norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
  6. use_checkpoint=False,**kwargs):'''
  7. img_size(int | tuple(int)): 输入图像尺寸. 默认: 224
  8. patch_size (int | tuple(int)): Patch尺寸. 默认: Swin-T参数配置表中 stage1中的96
  9. in_chans (int): 输入图像通道. 默认: 3
  10. num_classes (int): 分类数. 默认: 1000
  11. embed_dim (int): Patch embedding的输出通道. 默认: 96 (Swin-T参数配置表中 stage 1 中的 96-d)
  12. depths (tuple(int)):Swin Transformer Block 的个数. 默认:[2, 2, 6, 2] (Swin-T参数配置表中的[×2, ×2, ×6, ×2])
  13. num_heads (tuple(int)): 不同层 MSA 计算中的 head 数. 默认:[3, 6, 12, 24] (Swin-T参数配置表中的[head 3,head 6,head 12,head 24])
  14. window_size (int): W-MSA 和 SW-MSA 的 Window 尺寸. 默认: 7 (Swin-T参数配置表中的 “win.sz. 7×7”)
  15. mlp_ratio (float): 通过MLP的输出通道倍数. 默认: 4 (Swin-T参数配置表中的“dim 96”,“dim 192”,“dim 384”,“dim 768”可以看出)
  16. qkv_bias (bool): 使用 Linear 将输入映射到 qkv 时,Linear是否使用偏置. 默认: True
  17. qk_scale (float):qk缩放比例,如果是 None 则使用根号 dim_k 分之一. 默认: None
  18. drop_rate (float): dropout概率. 默认: 0
  19. attn_drop_rate (float): attention 中的 dropout 概率. 默认: 0
  20. drop_path_rate (float): attention 中的 droppath 概率. 默认: 0.1
  21. norm_layer (nn.Module): 归一化方式. 默认: nn.LayerNorm.
  22. ape (bool): 是否在 patch embedding 后使用绝对位置编码. 默认: False
  23. patch_norm (bool): 是否在 patch embedding 后使用归一化. 默认: True
  24. use_checkpoint (bool): 是否 checkpointing 节省内存. 默认: False
  25. '''super().__init__()
  26. self.num_classes = num_classes
  27. self.num_layers =len(depths)
  28. self.embed_dim = embed_dim
  29. self.ape = ape
  30. self.patch_norm = patch_norm
  31. '''
  32. 经过4个stage后的通道数(从96->768 即:96*2^(4-1)=768)
  33. '''
  34. self.num_features =int(embed_dim *2**(self.num_layers -1))
  35. self.mlp_ratio = mlp_ratio
  36. '''
  37. 将图片划分成没有重叠的多个patch
  38. PatchEmbed代码在下文中介绍
  39. patches_resolution = [img_size[0]//patch_size[0], img_size[1]//patch_size[1]] = [56,56]
  40. num_patches = patches_resolution[0] * patches_resolution[1] = 3136
  41. '''
  42. self.patch_embed = PatchEmbed(
  43. img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
  44. norm_layer=norm_layer if self.patch_norm elseNone)
  45. num_patches = self.patch_embed.num_patches
  46. patches_resolution = self.patch_embed.patches_resolution
  47. self.patches_resolution = patches_resolution
  48. '''
  49. 如果使用绝对位置编码则构建可学习的绝对位置编码参数:
  50. self.absolute_pos_embed : [1,3136,96]
  51. 默认不使用
  52. '''if self.ape:
  53. self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
  54. trunc_normal_(self.absolute_pos_embed, std=.02)'''
  55. pos_drop 以 drop_rate 概率进行 Dropout
  56. '''
  57. self.pos_drop = nn.Dropout(p=drop_rate)'''
  58. 构建首项为0,长度为depths(2+2+6+2=12)的等差数列,且最后一项小于drop_path_rate
  59. 也就是说 传入 BasicLayer 的 droppath 概率是递增的。
  60. 代码这里是让 drop_path_ratio 默认等于0.1
  61. 最后利用参数构建 depth(12) 层 BasicLayer 层
  62. BasicLayer 的代码在下文中介绍
  63. '''
  64. dpr =[x.item()for x in torch.linspace(0, drop_path_rate,sum(depths))]# stochastic depth decay rule# build layers
  65. self.layers = nn.ModuleList()for i_layer inrange(self.num_layers):
  66. layer = BasicLayer(dim=int(embed_dim *2** i_layer),#每个Basiclayer模块后通道数都翻倍'''
  67. 每个Basiclayer都进行了降采样
  68. 所以input_resolution每一次都要除以2
  69. '''
  70. input_resolution=(patches_resolution[0]//(2** i_layer),
  71. patches_resolution[1]//(2** i_layer)),
  72. depth=depths[i_layer],
  73. num_heads=num_heads[i_layer],
  74. window_size=window_size,
  75. mlp_ratio=self.mlp_ratio,
  76. qkv_bias=qkv_bias, qk_scale=qk_scale,
  77. drop=drop_rate, attn_drop=attn_drop_rate,
  78. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer +1])],
  79. norm_layer=norm_layer,'''
  80. 如果i_layer 是最后一层则不使用PatchMerging 来降采样
  81. '''
  82. downsample=PatchMerging if(i_layer < self.num_layers -1)elseNone,
  83. use_checkpoint=use_checkpoint)
  84. self.layers.append(layer)'''
  85. 进行归一化和平均池化
  86. 最后用一个Linear做预测head
  87. '''
  88. self.norm = norm_layer(self.num_features)
  89. self.avgpool = nn.AdaptiveAvgPool1d(1)
  90. self.head = nn.Linear(self.num_features, num_classes)if num_classes >0else nn.Identity()
  91. self.apply(self._init_weights)'''
  92. 初始化权重
  93. '''def_init_weights(self, m):ifisinstance(m, nn.Linear):
  94. trunc_normal_(m.weight, std=.02)ifisinstance(m, nn.Linear)and m.bias isnotNone:
  95. nn.init.constant_(m.bias,0)elifisinstance(m, nn.LayerNorm):
  96. nn.init.constant_(m.bias,0)
  97. nn.init.constant_(m.weight,1.0)defforward_features(self, x):'''
  98. 如图所示先进行patch embedding
  99. 如果使用绝对位置偏置就加上绝对位置编码
  100. '''
  101. x = self.patch_embed(x)if self.ape:
  102. x = x + self.absolute_pos_embed
  103. x = self.pos_drop(x)'''
  104. 循环执行Blocks
  105. '''for layer in self.layers:
  106. x = layer(x)'''
  107. 归一化并平均池化
  108. '''
  109. x = self.norm(x)# B L C
  110. x = self.avgpool(x.transpose(1,2))# B C 1
  111. x = torch.flatten(x,1)return x
  112. defforward(self, x):
  113. x = self.forward_features(x)'''
  114. 对swin transformer的特征提取进行预测
  115. '''
  116. x = self.head(x)return x

1 PatchEmbed

Swin Transformer中的

  1. PatchEmbed

模块和 VIT 中的 Linear Projection of Flattened Patches:

  1. PatchEmbed

模块差不多,可查看博文Vision Transformer(Pytorch版)代码阅读注释 查看,其主要思想是通过感受野大小等于步距大小的卷积来实现,与 VIT 不同的是其使用了

  1. nn.LayerNorm

  1. PatchEmbed

代码和注释如下:

  1. classPatchEmbed(nn.Module):""" Image to Patch Embedding
  2. Args:
  3. img_size (int): 图像尺寸. 默认: 224.
  4. patch_size (int): token尺寸. 默认: 4.
  5. in_chans (int): 图像通道. 默认: 3.
  6. embed_dim (int): patch embed通道. 默认: 96.
  7. norm_layer (nn.Module, optional): 归一化. Default: None
  8. """def__init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):super().__init__()'''
  9. self.image_size = (224,224)
  10. self.patch_size = (4,4)
  11. self.patches_resolution = [56,56]
  12. self.num_patches = 56*56=3136
  13. '''
  14. img_size = to_2tuple(img_size)
  15. patch_size = to_2tuple(patch_size)
  16. patches_resolution =[img_size[0]// patch_size[0], img_size[1]// patch_size[1]]
  17. self.img_size = img_size
  18. self.patch_size = patch_size
  19. self.patches_resolution = patches_resolution
  20. self.num_patches = patches_resolution[0]* patches_resolution[1]'''
  21. self.in_chans = 3
  22. self.embed_dim = 96
  23. '''
  24. self.in_chans = in_chans
  25. self.embed_dim = embed_dim
  26. '''
  27. self.proj = nn.Conv2d(3,96,(4,4),4)
  28. self.norm = nn.LayerNorm(96)
  29. '''
  30. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)if norm_layer isnotNone:
  31. self.norm = norm_layer(embed_dim)else:
  32. self.norm =Nonedefforward(self, x):
  33. B, C, H, W = x.shape
  34. # FIXME look at relaxing size constraintsassert H == self.img_size[0]and W == self.img_size[1], \
  35. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."'''
  36. self.proj(x):[B,3,224,224]->[B,96,56,56]
  37. flatten(2):[B,96,56,56]->[B,96,56*56]=[B,96,3136]
  38. transpose(1, 2):[B,96,3136]->[B,3136,96]
  39. self.norm(x):[B,3136,96]->[B,3136,96]
  40. '''
  41. x = self.proj(x).flatten(2).transpose(1,2)# B Ph*Pw Cif self.norm isnotNone:
  42. x = self.norm(x)return x

BasicLayer

代码中使用

  1. BasicLayer

来实现论文中的

  1. Swin Transformer Block

+

  1. PatchMerging

,对于最后一个

  1. BasicLayer

不使用

  1. PatchMerging

来降采样。

  1. BasicLayer

的代码和注释如下:

  1. classBasicLayer(nn.Module):""" A basic Swin Transformer layer for one stage.
  2. Args:
  3. dim (int): 输入特征图的通道数.
  4. input_resolution (tuple[int]): 输入特征图的分辨率大小.
  5. depth (int): SwinTransformerBlock的个数.
  6. num_heads (int): Muti-Head Self-Attention 中的head个数.
  7. window_size (int): window 大小.
  8. mlp_ratio (float): patch embedding通过MLP的通道倍数.
  9. qkv_bias (bool): 使用 Linear 将输入映射到 qkv 时,Linear是否使用偏置. 默认: True
  10. qk_scale (float):qk缩放比例,如果是 None 则使用根号 dim_k 分之一. 默认: None
  11. drop (float, optional): dropout概率. 默认: 0
  12. attn_drop (float, optional): attention 中的 dropout 概率. 默认: 0
  13. drop_path (float | tuple[float], optional): attention 中的 droppath 概率. 默认: 0.1
  14. norm_layer (nn.Module): 归一化方式. 默认: nn.LayerNorm.
  15. downsample (nn.Module | None, optional): 降采样层. 默认: None 代码使用PatchMerging
  16. use_checkpoint (bool): 是否 checkpointing 节省内存. 默认: False
  17. """def__init__(self, dim, input_resolution, depth, num_heads, window_size,
  18. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
  19. drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):super().__init__()
  20. self.dim = dim
  21. self.input_resolution = input_resolution
  22. self.depth = depth
  23. self.use_checkpoint = use_checkpoint
  24. '''
  25. 构建SwinTransformerBlock
  26. SwinTransformerBlock代码在下文介绍
  27. '''
  28. self.blocks = nn.ModuleList([
  29. SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
  30. num_heads=num_heads, window_size=window_size,'''
  31. 如果i是偶数,则表示是W-MSA,shift_size =0
  32. 如果i是奇数,则表示是SW-MSA,shift_size = window_size // 2
  33. '''
  34. shift_size=0if(i %2==0)else window_size //2,
  35. mlp_ratio=mlp_ratio,
  36. qkv_bias=qkv_bias, qk_scale=qk_scale,
  37. drop=drop, attn_drop=attn_drop,
  38. drop_path=drop_path[i]ifisinstance(drop_path,list)else drop_path,
  39. norm_layer=norm_layer)for i inrange(depth)])'''
  40. 使用PatchMerging进行降采样
  41. PatchMerging代码在下文介绍
  42. '''if downsample isnotNone:
  43. self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)else:
  44. self.downsample =Nonedefforward(self, x):for blk in self.blocks:if self.use_checkpoint:
  45. x = checkpoint.checkpoint(blk, x)else:
  46. x = blk(x)if self.downsample isnotNone:
  47. x = self.downsample(x)return x

SwinTransformerBlock

  1. SwinTransformerBlock

的结构如下:
在这里插入图片描述

Mlp

此处和 VIT 中

  1. MLP

一模一样,可查看Vision Transformer(Pytorch版)代码阅读注释 。代码也很简单,就不再做任何赘述了。

  1. Mlp

的代码如下:

  1. classMlp(nn.Module):def__init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()
  2. out_features = out_features or in_features
  3. hidden_features = hidden_features or in_features
  4. self.fc1 = nn.Linear(in_features, hidden_features)
  5. self.act = act_layer()
  6. self.fc2 = nn.Linear(hidden_features, out_features)
  7. self.drop = nn.Dropout(drop)defforward(self, x):
  8. x = self.fc1(x)
  9. x = self.act(x)
  10. x = self.drop(x)
  11. x = self.fc2(x)
  12. x = self.drop(x)return x

window_partition

  1. W-MSA

  1. SW-MSA

首先需要将特征图拆分成多个windows。
在这里插入图片描述

  1. window_partition

的代码和注释如下:

  1. defwindow_partition(x, window_size):"""
  2. Args:
  3. x: (B, H, W, C)
  4. window_size (int): window size
  5. Returns:
  6. windows: (num_windows*B, window_size, window_size, C)
  7. """
  8. B, H, W, C = x.shape
  9. '''
  10. [B, H, W, C] -> [BHW//(window_size*window_size), window_size, window_size, C]
  11. '''
  12. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  13. windows = x.permute(0,1,3,2,4,5).contiguous().view(-1, window_size, window_size, C)return windows

window_reverse

在对每一个windows进行

  1. WSA

计算以后需要将其还原成正常的特征图传入下一模块中。其实就是

  1. window_partition

的逆过程。

  1. window_reverse

的代码和注释如下:

  1. defwindow_reverse(windows, window_size, H, W):"""
  2. Args:
  3. windows: (num_windows*B, window_size, window_size, C)
  4. window_size (int): Window size
  5. H (int): Height of image
  6. W (int): Width of image
  7. Returns:
  8. x: (B, H, W, C)
  9. """
  10. B =int(windows.shape[0]/(H * W / window_size / window_size))
  11. x = windows.view(B, H // window_size, W // window_size, window_size, window_size,-1)
  12. x = x.permute(0,1,3,2,4,5).contiguous().view(B, H, W,-1)return x

WindowAttention

WindowAttention就是在 Vision Transformer 模块的Attention基础上加入了相对位置偏移

  1. relative_position_bias_table

(即论文中提出的 Relative Position Bias)来提升精度:
在这里插入图片描述
生成相对位置偏置的过程(以一个head为例,假设

  1. window_h = 2,window_w = 2

,相关代码已在图中标出):
1.随机生成相对位置偏置表

  1. relative_position_bias_table


在这里插入图片描述

  1. self.relative_position_bias_table = nn.Parameter(
  2. torch.zeros((2* window_size[0]-1)*(2* window_size[1]-1), num_heads))# 2*Wh-1 * 2*Ww-1, nH

2.首先windows内部每个像素都有自己的位置编码,其绝对位置编码的坐标

  1. abs_coords

如果以左上角为原点,则如下图:
在这里插入图片描述

  1. coords_h = torch.arange(self.window_size[0])
  2. coords_w = torch.arange(self.window_size[1])
  3. coords = torch.stack(torch.meshgrid([coords_h, coords_w]))# 2, Wh, Ww
  4. coords_flatten = torch.flatten(coords,1)# 2, Wh*Ww

3.为了获得每个绝对坐标相对于其他坐标的相对位置,则需要用每个绝对坐标减去其他绝对坐标,即:

  • 用 ( 0 , 0 ) (0,0) (0,0) 分别减去 ( 0 , 1 ) (0,1) (0,1), ( 1 , 0 ) (1,0) (1,0), ( 1 , 1 ) (1,1) (1,1)
  • 用 ( 0 , 1 ) (0,1) (0,1) 分别减去 ( 0 , 0 ) (0,0) (0,0), ( 1 , 0 ) (1,0) (1,0), ( 1 , 1 ) (1,1) (1,1)
  • 用 ( 1 , 0 ) (1,0) (1,0) 分别减去 ( 0 , 0 ) (0,0) (0,0), ( 0 , 1 ) (0,1) (0,1), ( 1 , 1 ) (1,1) (1,1)
  • 用 ( 1 , 1 ) (1,1) (1,1) 分别减去 ( 0 , 0 ) (0,0) (0,0), ( 0 , 1 ) (0,1) (0,1), ( 1 , 0 ) (1,0) (1,0)

以此得到相对坐标

  1. relative_coords

。这样每一个像素相对于其他像素的相对位置关系就得到了。
代码中实现过程如下:
在这里插入图片描述

  1. relative_coords = coords_flatten[:,:,None]- coords_flatten[:,None,:]# 2, Wh*Ww, Wh*Ww
  2. relative_coords = relative_coords.permute(1,2,0).contiguous()# Wh*Ww, Wh*Ww, 2

4.将二维的坐标转换成一维的索引

  1. relative_position_index

  • 先将每一行加上 w i n d o w − h − 1 window_-h-1 window−​h−1,每一列加上 w i n d o w − w − 1 window_-w-1 window−​w−1
  • 然后将每一行乘上 2 ∗ w i n d o w − w − 1 2*window_-w-1 2∗window−​w−1
  • 最后将行列坐标相加,得到relative_position_index

在这里插入图片描述

  1. relative_coords[:,:,0]+= self.window_size[0]-1# shift to start from 0
  2. relative_coords[:,:,1]+= self.window_size[1]-1
  3. relative_coords[:,:,0]*=2* self.window_size[1]-1
  4. relative_position_index = relative_coords.sum(-1)# Wh*Ww, Wh*Ww

这里作用的用意是为了让不同的相对关系只对应一种数字。如果一开始直接将行列坐标相加,会发现

  1. (
  2. 0
  3. ,
  4. 1
  5. )
  6. (0,1)
  7. (0,1)
  8. (
  9. 1
  10. ,
  11. 0
  12. )
  13. (1,0)
  14. (1,0) 这两种关系无法分辨(相加都是1)。而如果不乘上
  15. 2
  16. w
  17. i
  18. n
  19. d
  20. o
  21. w
  22. w
  23. 1
  24. 2*window_-w-1
  25. 2window−​w1,会发现
  26. (
  27. 1
  28. ,
  29. 1
  30. )
  31. (1,1)
  32. (1,1)
  33. (
  34. 2
  35. ,
  36. 0
  37. )
  38. (2,0)
  39. (2,0) 这两种关系无法分辨(相加都是2)。所以经过几次变换以后再转换成1维的坐标就可以得到唯一的相对位置关系,如下图:

在这里插入图片描述

  • 0代表右下的相对位置关系
  • 1代表正下的相对位置关系
  • 2代表坐下的相对位置关系
  • 3代表正右的相对位置关系
  • 4代表和自己的相对位置关系
  • 5代表正左的相对位置关系
  • 6代表右上的相对位置关系
  • 7代表正上的相对位置关系
  • 8代表左上的相对位置关系

5.通过相对位置索引

  1. relative_position_index

在相对位置偏置表

  1. relative_position_bias_table

中取找对应的相对位置偏置

  1. relative_position_bias


在这里插入图片描述

  1. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  2. self.window_size[0]* self.window_size[1], self.window_size[0]* self.window_size[1],-1)# Wh*Ww,Wh*Ww,nH
  1. WindowAttention

的代码和注释如下(大部分代码在Vision Transformer(Pytorch版)代码阅读注释的Multi-Head Attention部分已经介绍,主要还是相对位置偏置部分不同):

  1. classWindowAttention(nn.Module):""" Window based multi-head self attention (W-MSA) module with relative position bias.
  2. It supports both of shifted and non-shifted window.
  3. Args:
  4. dim (int): 输入特征图的通道
  5. window_size (tuple[int]): window的尺寸.
  6. num_heads (int): muti-head self-attention的head个数.
  7. qkv_bias (bool, optional): 是否使用 qkv 偏置(即使用 Linear 将输入映射到 qkv 时,Linear是否使用 bias )
  8. qk_scale (float | None, optional): qk缩放比例,默认使用根号 dim_k 分之一
  9. attn_drop (float, optional): attention 中的 dropout 概率
  10. proj_drop (float, optional): linear 中的 dropout 概率
  11. """def__init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):'''
  12. VIT源码阅读中已经讲过了
  13. '''super().__init__()
  14. self.dim = dim
  15. self.window_size = window_size # Wh, Ww
  16. self.num_heads = num_heads
  17. head_dim = dim // num_heads
  18. self.scale = qk_scale or head_dim **-0.5'''
  19. 上述生成相对位置偏置的过程对应的代码:
  20. 1.生成相对位置偏置表
  21. relative_position_bias_table.shape = [2*window_h-1 * 2*window_w-1, num_heads]
  22. '''
  23. self.relative_position_bias_table = nn.Parameter(
  24. torch.zeros((2* window_size[0]-1)*(2* window_size[1]-1), num_heads))# 2*Wh-1 * 2*Ww-1, nH'''
  25. 2.生成绝对位置坐标
  26. coords.shape = [2, window_h, window_w]
  27. '''
  28. coords_h = torch.arange(self.window_size[0])
  29. coords_w = torch.arange(self.window_size[1])
  30. coords = torch.stack(torch.meshgrid([coords_h, coords_w]))# 2, Wh, Ww'''
  31. 3.生成相对位置坐标
  32. relative_coords.shape = [window_h*window_w, window_h*window_w, 2]
  33. '''
  34. coords_flatten = torch.flatten(coords,1)# 2, Wh*Ww
  35. relative_coords = coords_flatten[:,:,None]- coords_flatten[:,None,:]# 2, Wh*Ww, Wh*Ww
  36. relative_coords = relative_coords.permute(1,2,0).contiguous()# Wh*Ww, Wh*Ww, 2'''
  37. 4.将二维相对位置坐标转换成一维相对位置索引
  38. relative_position_index = [window_h*window_w, window_h*window_w]
  39. '''
  40. relative_coords[:,:,0]+= self.window_size[0]-1# shift to start from 0
  41. relative_coords[:,:,1]+= self.window_size[1]-1
  42. relative_coords[:,:,0]*=2* self.window_size[1]-1
  43. relative_position_index = relative_coords.sum(-1)# Wh*Ww, Wh*Ww
  44. self.register_buffer("relative_position_index", relative_position_index)'''
  45. VIT源码阅读中已经讲过了,是一模一样的
  46. '''
  47. self.qkv = nn.Linear(dim, dim *3, bias=qkv_bias)
  48. self.attn_drop = nn.Dropout(attn_drop)
  49. self.proj = nn.Linear(dim, dim)
  50. self.proj_drop = nn.Dropout(proj_drop)
  51. trunc_normal_(self.relative_position_bias_table, std=.02)
  52. self.softmax = nn.Softmax(dim=-1)defforward(self, x, mask=None):"""
  53. Args:
  54. x: shape为(num_windows*B, N, C)的特征图
  55. mask: 为了使SW-MSA中不相邻的子窗口之间的不进行qk匹配,论文阅读笔记中已介绍,
  56. mask的生成和使用原理在后面介绍SwinTransformerBlock的代码时会介绍
  57. """
  58. B_, N, C = x.shape
  59. '''
  60. 将VIY和attention时已经讲过了,不再介绍
  61. '''
  62. qkv = self.qkv(x).reshape(B_, N,3, self.num_heads, C // self.num_heads).permute(2,0,3,1,4)
  63. q, k, v = qkv[0], qkv[1], qkv[2]# make torchscript happy (cannot use tensor as tuple)
  64. q = q * self.scale
  65. attn =(q @ k.transpose(-2,-1))'''
  66. 5.根据相对位置索引在相对位置偏置表中找到对应的相对位置偏置并加上
  67. '''
  68. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  69. self.window_size[0]* self.window_size[1], self.window_size[0]* self.window_size[1],-1)# Wh*Ww,Wh*Ww,nH
  70. relative_position_bias = relative_position_bias.permute(2,0,1).contiguous()# nH, Wh*Ww, Wh*Ww
  71. attn = attn + relative_position_bias.unsqueeze(0)'''
  72. 利用mask掩膜计算将不相邻的子窗口之间使用softmax抑制来去除qk匹配
  73. '''if mask isnotNone:
  74. nW = mask.shape[0]
  75. attn = attn.view(B_ // nW, nW, self.num_heads, N, N)+ mask.unsqueeze(1).unsqueeze(0)
  76. attn = attn.view(-1, self.num_heads, N, N)
  77. attn = self.softmax(attn)else:
  78. attn = self.softmax(attn)
  79. attn = self.attn_drop(attn)
  80. x =(attn @ v).transpose(1,2).reshape(B_, N, C)
  81. x = self.proj(x)
  82. x = self.proj_drop(x)return x

Block

在论文阅读笔记:Swin Transformer中讲过Efficient batch computation for shifted configuration
在这里插入图片描述
其将A、B、C三个框中的 window 移动到四个 4 × 4 红色框的对应位置,使其凑成四个 4 × 4 的window。由于有几个 window 是由不相邻的子窗口组成,需要通过Masked MSA 掩膜计算来限制每个 window 中的不同子窗口的 MSA。

掩膜计算相关代码如下:

  1. '''
  2. self.input_resolution为当前特征图的大小
  3. img_mask 用于记录同特征图上每个像素的掩膜权重
  4. h_slices w_slices 分别为宽高切片
  5. self.shift_size 只有为SW-MSA计算时为self.window_size//2,W-MSA计算时为0
  6. '''
  7. H, W = self.input_resolution
  8. img_mask = torch.zeros((1, H, W,1))# 1 H W 1
  9. h_slices =(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))
  10. w_slices =(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))'''
  11. 将不同区域按0到8进行标记
  12. '''
  13. cnt =0for h in h_slices:for w in w_slices:
  14. img_mask[:, h, w,:]= cnt
  15. cnt +=1
  16. mask_windows = window_partition(img_mask, self.window_size)# nW, window_size, window_size, 1
  17. mask_windows = mask_window s.view(-1, self.window_size * self.window_size)
  18. attn_mask = mask_windows.unsqueeze(1)- mask_windows.unsqueeze(2)
  19. attn_mask = attn_mask.masked_fill(attn_mask !=0,float(-100.0)).masked_fill(attn_mask ==0,float(0.0))

过程如下(假设

  1. self.input_resolution=(8,8)

  1. self.window_size=4

  1. self.shift_size=self.window_size//2=2

):
1.按特征图中的 window 通过切片方式生成划分子窗口的

  1. img_mask


在这里插入图片描述
2.对

  1. img_mask

中不同区域按0到8进行标记:
在这里插入图片描述
3.将window拆分并展开:
在这里插入图片描述
4.每个window中分别用每个像素的标记值互减,计算像素之间是否在同一个子window中得到

  1. attn_mask(attn_mask.shape=[4,16,16])

  1. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

举例说明:
第一个window标记向量为:

  1. [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]

,依次减去每一个标记值,进行了16×16次减法,得到16×16的矩阵

  1. [[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],...,[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]]

第二个window标记向量为:

  1. [1,1,2,2,1,1,2,2,1,1,2,2,1,1,2,2]

,依次减去每一个标记值,进行了16×16次减法,得到16×16的矩阵

  1. [[0,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1],...,[-1,-1,0,0,-1,-1,0,0,-1,-1,0,0,-1,-1,0,0]]

第三个window标记向量为:

  1. [3,3,3,3,3,3,3,3,6,6,6,6,6,6,6,6]

,依次减去每一个标记值,进行了16×16次减法,得到16×16的矩阵

  1. [[0,0,0,0,0,0,0,0,3,3,3,3,3,3,3,3],...,[0,0,0,0,0,0,0,0,-3,-3,-3,-3,-3,-3,-3,-3]]

第四个window标记向量为:

  1. [4,4,5,5,4,4,5,5,7,7,8,8,7,7,8,8]

,依次减去每一个标记值,进行了16×16次减法,得到16×16的矩阵

  1. [[0,0,1,1,0,0,1,1,3,3,4.4,3,3,4,4],...,[-4,-4,-3,-3,-4,-4,-3,-3,-1,-1,0,0,-1,-1,0,0]]
  1. attn_mask

中某个window第

  1. i
  2. i
  3. i 行第
  4. j
  5. j
  6. j 列如果为 0,则表示第
  7. i
  8. i
  9. i 个像素和第
  10. j
  11. j
  12. j 个像素属于同一个子窗口。

5.给

  1. attn_mask

赋值(如果为0,则等于0,不为0,则等于-100):

  1. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

结合

  1. WindowAttention

的代码:

  1. nW = mask.shape[0]
  2. attn = attn.view(B_ // nW, nW, self.num_heads, N, N)+ mask.unsqueeze(1).unsqueeze(0)
  3. attn = attn.view(-1, self.num_heads, N, N)
  4. attn = self.softmax(attn)

其实现原理如下(以第0个像素为例):

第0个像素与其他像素的再完成

  1. Q
  2. K
  3. T
  4. /
  5. d
  6. +
  7. B
  8. QK^T/\sqrt{d}+B
  9. QKT/d​+B的计算后,得到
  10. α
  11. 0
  12. ,
  13. 0
  14. \alpha_{0,0}
  15. α0,0 ,
  16. α
  17. 0
  18. ,
  19. 1
  20. \alpha_{0,1}
  21. α0,1​…
  22. α
  23. 0
  24. ,
  25. 15
  26. \alpha_{0,15}
  27. α0,15​,再进行掩膜计算
  1. attn = attn + mask

,这样与像素0不在同一子窗口的像素结果就会减去100,最后通过softmax使其等于0,这样就抑制了不同窗口间的qk匹配。
再进行掩膜计算,
在这里插入图片描述
那么代码是如何实现子窗口之间的移动的呢?
窗口移动代码如下:

  1. '''
  2. X.shape = [B,H,W,C]
  3. 在 H 所在维度上滚动-self.shift_size个像素,
  4. 在 W 所在维度滚动-self.shift_size个像素
  5. shift_size = window_size//2
  6. 注:torch.roll正方向是往下
  7. '''if self.shift_size >0:
  8. shifted_x = torch.roll(x, shifts=(-self.shift_size,-self.shift_size), dims=(1,2))else:
  9. shifted_x = x

其实现过程如下(假设

  1. window_size = 4

,

  1. shift_size = window_size //2 = 2

):
在这里插入图片描述
接着对图像进行拆分即可:
在这里插入图片描述

  1. SwinTransformerBlock

的代码和注释如下:

  1. classSwinTransformerBlock(nn.Module):""" Swin Transformer Block.
  2. Args:
  3. dim (int): 特征图的维度.
  4. input_resolution (tuple[int]): 输入特征图的分辨率.
  5. num_heads (int): muti-head self-attention的headg个数.
  6. window_size (int): 窗口尺寸.
  7. shift_size (int): SW-MSA的偏移尺寸,为window_size // 2.
  8. mlp_ratio (float): patch embedding通过MLP的通道倍数.
  9. qkv_bias (bool): 使用 Linear 将输入映射到 qkv 时,Linear是否使用偏置. 默认: True
  10. qk_scale (float):qk缩放比例,如果是 None 则使用根号 dim_k 分之一. 默认: None
  11. drop (float, optional): dropout概率. 默认: 0
  12. attn_drop (float, optional): attention 中的 dropout 概率. 默认: 0
  13. drop_path (float | tuple[float], optional): attention 中的 droppath 概率. 默认: 0.1
  14. act_layer (nn.Module, optional): 激活函数. 默认: nn.GELU
  15. norm_layer (nn.Module, optional): 归一化层. 默认: nn.LayerNorm
  16. """def__init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
  17. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
  18. act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()
  19. self.dim = dim
  20. self.input_resolution = input_resolution
  21. self.num_heads = num_heads
  22. self.window_size = window_size
  23. self.shift_size = shift_size
  24. self.mlp_ratio = mlp_ratio
  25. '''
  26. 如果窗口尺寸已经大于分辨率,则让窗口尺寸等于分辨率H、W 中较小的那一个
  27. '''ifmin(self.input_resolution)<= self.window_size:# if window size is larger than input resolution, we don't partition windows
  28. self.shift_size =0
  29. self.window_size =min(self.input_resolution)assert0<= self.shift_size < self.window_size,"shift_size must in 0-window_size"
  30. self.norm1 = norm_layer(dim)
  31. self.attn = WindowAttention(
  32. dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
  33. qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
  34. self.drop_path = DropPath(drop_path)if drop_path >0.else nn.Identity()
  35. self.norm2 = norm_layer(dim)
  36. mlp_hidden_dim =int(dim * mlp_ratio)
  37. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)'''
  38. 如果偏移量大于0,则表示使用SW-MSA
  39. 计算attn_mask,原理如上述过程。
  40. '''if self.shift_size >0:# calculate attention mask for SW-MSA
  41. H, W = self.input_resolution
  42. img_mask = torch.zeros((1, H, W,1))# 1 H W 1
  43. h_slices =(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))
  44. w_slices =(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))
  45. cnt =0for h in h_slices:for w in w_slices:
  46. img_mask[:, h, w,:]= cnt
  47. cnt +=1
  48. mask_windows = window_partition(img_mask, self.window_size)# nW, window_size, window_size, 1
  49. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  50. attn_mask = mask_windows.unsqueeze(1)- mask_windows.unsqueeze(2)
  51. attn_mask = attn_mask.masked_fill(attn_mask !=0,float(-100.0)).masked_fill(attn_mask ==0,float(0.0))else:
  52. attn_mask =None
  53. self.register_buffer("attn_mask", attn_mask)defforward(self, x):
  54. H, W = self.input_resolution
  55. B, L, C = x.shape
  56. assert L == H * W,"input feature has wrong size"
  57. shortcut = x
  58. x = self.norm1(x)
  59. x = x.view(B, H, W, C)'''
  60. H 所在维度上滚动-self.shift_size个像素,
  61. W 所在维度滚动-self.shift_size个像素
  62. 注:torch.roll正方向是从上往下和从左往右
  63. '''if self.shift_size >0:
  64. shifted_x = torch.roll(x, shifts=(-self.shift_size,-self.shift_size), dims=(1,2))else:
  65. shifted_x = x
  66. '''
  67. 划分窗口
  68. '''
  69. x_windows = window_partition(shifted_x, self.window_size)# nW*B, window_size, window_size, C
  70. x_windows = x_windows.view(-1, self.window_size * self.window_size, C)# nW*B, window_size*window_size, C'''
  71. 每个window进行MSA计算
  72. '''
  73. attn_windows = self.attn(x_windows, mask=self.attn_mask)# nW*B, window_size*window_size, C'''
  74. 合并窗口
  75. '''
  76. attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
  77. shifted_x = window_reverse(attn_windows, self.window_size, H, W)# B H' W' C'''
  78. 如果是SW-MSA还需要将窗口滑动回来
  79. '''if self.shift_size >0:
  80. x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1,2))else:
  81. x = shifted_x
  82. x = x.view(B, H * W, C)'''
  83. 残差结构
  84. '''
  85. x = shortcut + self.drop_path(x)
  86. x = x + self.drop_path(self.mlp(self.norm2(x)))return x

PatchMerging

论文阅读笔记:Swin Transformer已对PatchMerging进行了介绍,其流程图如下(最终通道翻倍,大小变为原来的一半达到降采样的效果):
在这里插入图片描述

  1. PatchMerging

的代码和注释如下:

  1. classPatchMerging(nn.Module):""" Patch Merging Layer.
  2. Args:
  3. input_resolution (tuple[int]): 输入特征图的大小
  4. dim (int): 输入特征图的通道
  5. norm_layer (nn.Module, optional): 归一化层. Default: nn.LayerNorm
  6. """def__init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):super().__init__()
  7. self.input_resolution = input_resolution
  8. self.dim = dim
  9. '''
  10. 将输出通道在最初的特征图通道上翻倍(即流程图中最4个步骤和第1个步骤的倍数关系)
  11. '''
  12. self.reduction = nn.Linear(4* dim,2* dim, bias=False)
  13. self.norm = norm_layer(4* dim)defforward(self, x):"""
  14. x: B, H*W, C
  15. """
  16. H, W = self.input_resolution
  17. B, L, C = x.shape
  18. assert L == H * W,"input feature has wrong size"assert H %2==0and W %2==0,f"x size ({H}*{W}) are not even."
  19. x = x.view(B, H, W, C)'''
  20. 将输入按流程图中的方式拆分成4份
  21. 在concat以后进行flatten
  22. '''
  23. x0 = x[:,0::2,0::2,:]# B H/2 W/2 C
  24. x1 = x[:,1::2,0::2,:]# B H/2 W/2 C
  25. x2 = x[:,0::2,1::2,:]# B H/2 W/2 C
  26. x3 = x[:,1::2,1::2,:]# B H/2 W/2 C
  27. x = torch.cat([x0, x1, x2, x3],-1)# B H/2 W/2 4*C
  28. x = x.view(B,-1,4* C)# B H/2*W/2 4*C'''
  29. 进行LayerNorm并再用一个全连接层输出
  30. '''
  31. x = self.norm(x)
  32. x = self.reduction(x)return x

本文转载自: https://blog.csdn.net/Z960515/article/details/122784141
版权归原作者 HollowKnightZ 所有, 如有侵权,请联系我们删除。

“Swin Transformer代码阅读注释”的评论:

还没有评论