0


VAN(大核注意力机制)

Visual-Attention-Network/VAN-Classification (github.com)

transformer在视觉领域得到良好的效果,是因为它可以捕捉长距离的信息。在视觉领域,通常有两种办法去获得长距离的信息,一是基于transformer的自注意力机制 ,二是大内核卷积。自注意力机制源于NLP,虽然在视觉领域得到很好的效果,但是仍然存在一些问题。比如说自注意力机制将2维的图像数据展开破坏了图像2D结构,而且其计算量和内存占用也比较大。大内核卷积,会引入大量的参数和计算量。作者基于这些问题,提出了大核注意力机制(LKA)。大内核注意力机制结合了卷积运算的局部感受野和旋转不变性和自注意力机制的长距离信息。

LKA

类似于mobilenet的深度可分离卷积,将一个大内核卷积分解。

将一个卷积核大小为K的卷积分解为三个卷积的和,分别是卷积核大小为K/d的深度卷积、卷积核大小为(2d-1)膨胀率为d的深度膨胀卷积,通道卷积(1*1卷积)。

下表介绍了卷积,自注意力机制,LKA(大核注意力机制)的特点

  1. class LKA(nn.Module):
  2. def __init__(self, dim):
  3. super().__init__()
  4. self.conv0 = nn.Conv2d(dim, dim, 7, padding=7//2, groups=dim) ###深度可分离卷积 卷积核的大小(2d-1)
  5. self.conv_spatial = nn.Conv2d(dim, dim, 9, stride=1, padding=((9//2)*4), groups=dim, dilation=4) ###空洞率为4的深度可分离卷积 (卷积核大小 K/d)
  6. self.conv1 = nn.Conv2d(dim, dim, 1) ###逐点卷积
  7. def forward(self, x):
  8. u = x.clone()
  9. attn = self.conv0(x)
  10. attn = self.conv_spatial(attn)
  11. attn = self.conv1(attn)
  12. return u * attn

VAN

VAN结构是一个非常简单的层次结构,具有四个阶段,每个阶段的图像分辨率减半

作者在每一个阶段采用卷积核的步长控制下采样的幅度,然后就是堆叠下面的结构进行特征的提取。

  1. class Attention(nn.Module):
  2. def __init__(self, d_model):
  3. super().__init__()
  4. self.proj_1 = nn.Conv2d(d_model, d_model, 1)
  5. self.activation = nn.GELU()
  6. self.spatial_gating_unit = LKA(d_model)
  7. self.proj_2 = nn.Conv2d(d_model, d_model, 1)
  8. def forward(self, x):
  9. shorcut = x.clone()
  10. x = self.proj_1(x)
  11. x = self.activation(x)
  12. x = self.spatial_gating_unit(x)
  13. x = self.proj_2(x)
  14. x = x + shorcut
  15. return x
  16. class Block(nn.Module):
  17. def __init__(self, dim, mlp_ratio=4., drop=0.,drop_path=0., act_layer=nn.GELU):
  18. super().__init__()
  19. self.norm1 = nn.BatchNorm2d(dim)
  20. self.attn = Attention(dim)
  21. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  22. self.norm2 = nn.BatchNorm2d(dim)
  23. mlp_hidden_dim = int(dim * mlp_ratio)
  24. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  25. layer_scale_init_value = 1e-2
  26. self.layer_scale_1 = nn.Parameter(
  27. layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  28. self.layer_scale_2 = nn.Parameter(
  29. layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  30. self.apply(self._init_weights)
  31. def _init_weights(self, m):
  32. if isinstance(m, nn.Linear):
  33. trunc_normal_(m.weight, std=.02)
  34. if isinstance(m, nn.Linear) and m.bias is not None:
  35. nn.init.constant_(m.bias, 0)
  36. elif isinstance(m, nn.LayerNorm):
  37. nn.init.constant_(m.bias, 0)
  38. nn.init.constant_(m.weight, 1.0)
  39. elif isinstance(m, nn.Conv2d):
  40. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  41. fan_out //= m.groups
  42. m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
  43. if m.bias is not None:
  44. m.bias.data.zero_()
  45. def forward(self, x):
  46. x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
  47. x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
  48. return x

实验

图像分类

设置

设置部分主要介绍了数据集的处理,一些数据增强的手段,模型训练的一些设置,具体请看原文。

消融实验

验证了LKA各个部分的有效性

上表中第一行数据,说明了深度卷积可以充分利用图像的局部上下文信息。没有它,准确率下降了0.5

上表中的第二行数据,说明了深度扩张卷积可以捕获长范围的依赖

第三行对应得结构是图3b,第四行数据对应图3 c

第五行数据表明了1*1卷积可以捕获通道维度的关系。

第六行数据,说明了图3 1 中没有必要存在sigmoid函数。sigmoid用于将注意力图归一化到0-1之间。

最后一行,是基线模型。

通过以上分析,我们可以发现我们提出的LKA可以利用本地信息,捕获长距离依赖关系,并且在通道和空间维度上都具有适应性。此外,实验结果证明所有组件对于识别任务都是有效的。尽管标准卷积可以充分利用本地上下文信息,但它忽略了长期依赖性和适应性。对于自我注意,尽管它可以捕获远程依赖并在空间维度上具有适应性,但它忽略了局部信息和信道维度上的适应性

表6 验证了卷积核大小的影响

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from functools import partial
  5. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  6. from timm.models.registry import register_model
  7. from timm.models.vision_transformer import _cfg
  8. import math
  9. class Mlp(nn.Module):
  10. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  11. super().__init__()
  12. out_features = out_features or in_features
  13. hidden_features = hidden_features or in_features
  14. self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
  15. self.dwconv = DWConv(hidden_features)
  16. self.act = act_layer()
  17. self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
  18. self.drop = nn.Dropout(drop)
  19. self.apply(self._init_weights)
  20. def _init_weights(self, m):
  21. if isinstance(m, nn.Linear):
  22. trunc_normal_(m.weight, std=.02)
  23. if isinstance(m, nn.Linear) and m.bias is not None:
  24. nn.init.constant_(m.bias, 0)
  25. elif isinstance(m, nn.LayerNorm):
  26. nn.init.constant_(m.bias, 0)
  27. nn.init.constant_(m.weight, 1.0)
  28. elif isinstance(m, nn.Conv2d):
  29. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  30. fan_out //= m.groups
  31. m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
  32. if m.bias is not None:
  33. m.bias.data.zero_()
  34. def forward(self, x):
  35. x = self.fc1(x)
  36. x = self.dwconv(x)
  37. x = self.act(x)
  38. x = self.drop(x)
  39. x = self.fc2(x)
  40. x = self.drop(x)
  41. return x
  42. class LKA(nn.Module):
  43. def __init__(self, dim):
  44. super().__init__()
  45. self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
  46. self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
  47. self.conv1 = nn.Conv2d(dim, dim, 1)
  48. def forward(self, x):
  49. u = x.clone()
  50. attn = self.conv0(x)
  51. attn = self.conv_spatial(attn)
  52. attn = self.conv1(attn)
  53. return u * attn
  54. class Attention(nn.Module):
  55. def __init__(self, d_model):
  56. super().__init__()
  57. self.proj_1 = nn.Conv2d(d_model, d_model, 1)
  58. self.activation = nn.GELU()
  59. self.spatial_gating_unit = LKA(d_model)
  60. self.proj_2 = nn.Conv2d(d_model, d_model, 1)
  61. def forward(self, x):
  62. shorcut = x.clone()
  63. x = self.proj_1(x)
  64. x = self.activation(x)
  65. x = self.spatial_gating_unit(x)
  66. x = self.proj_2(x)
  67. x = x + shorcut
  68. return x
  69. class Block(nn.Module):
  70. def __init__(self, dim, mlp_ratio=4., drop=0.,drop_path=0., act_layer=nn.GELU):
  71. super().__init__()
  72. self.norm1 = nn.BatchNorm2d(dim)
  73. self.attn = Attention(dim)
  74. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  75. self.norm2 = nn.BatchNorm2d(dim)
  76. mlp_hidden_dim = int(dim * mlp_ratio)
  77. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  78. layer_scale_init_value = 1e-2
  79. self.layer_scale_1 = nn.Parameter(
  80. layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  81. self.layer_scale_2 = nn.Parameter(
  82. layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  83. self.apply(self._init_weights)
  84. def _init_weights(self, m):
  85. if isinstance(m, nn.Linear):
  86. trunc_normal_(m.weight, std=.02)
  87. if isinstance(m, nn.Linear) and m.bias is not None:
  88. nn.init.constant_(m.bias, 0)
  89. elif isinstance(m, nn.LayerNorm):
  90. nn.init.constant_(m.bias, 0)
  91. nn.init.constant_(m.weight, 1.0)
  92. elif isinstance(m, nn.Conv2d):
  93. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  94. fan_out //= m.groups
  95. m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
  96. if m.bias is not None:
  97. m.bias.data.zero_()
  98. def forward(self, x):
  99. x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
  100. x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
  101. return x
  102. class OverlapPatchEmbed(nn.Module):
  103. """ Image to Patch Embedding
  104. """
  105. def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
  106. super().__init__()
  107. patch_size = to_2tuple(patch_size)
  108. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
  109. padding=(patch_size[0] // 2, patch_size[1] // 2))
  110. self.norm = nn.BatchNorm2d(embed_dim)
  111. self.apply(self._init_weights)
  112. def _init_weights(self, m):
  113. if isinstance(m, nn.Linear):
  114. trunc_normal_(m.weight, std=.02)
  115. if isinstance(m, nn.Linear) and m.bias is not None:
  116. nn.init.constant_(m.bias, 0)
  117. elif isinstance(m, nn.LayerNorm):
  118. nn.init.constant_(m.bias, 0)
  119. nn.init.constant_(m.weight, 1.0)
  120. elif isinstance(m, nn.Conv2d):
  121. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  122. fan_out //= m.groups
  123. m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
  124. if m.bias is not None:
  125. m.bias.data.zero_()
  126. def forward(self, x):
  127. x = self.proj(x)
  128. _, _, H, W = x.shape
  129. x = self.norm(x)
  130. return x, H, W
  131. class VAN(nn.Module):
  132. def __init__(self, img_size=224, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
  133. mlp_ratios=[4, 4, 4, 4], drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
  134. depths=[3, 4, 6, 3], num_stages=4, flag=False):
  135. super().__init__()
  136. if flag == False:
  137. self.num_classes = num_classes
  138. self.depths = depths
  139. self.num_stages = num_stages
  140. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  141. cur = 0
  142. for i in range(num_stages):
  143. patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
  144. patch_size=7 if i == 0 else 3,
  145. stride=4 if i == 0 else 2,
  146. in_chans=in_chans if i == 0 else embed_dims[i - 1],
  147. embed_dim=embed_dims[i])
  148. block = nn.ModuleList([Block(
  149. dim=embed_dims[i], mlp_ratio=mlp_ratios[i], drop=drop_rate, drop_path=dpr[cur + j])
  150. for j in range(depths[i])])
  151. norm = norm_layer(embed_dims[i])
  152. cur += depths[i]
  153. setattr(self, f"patch_embed{i + 1}", patch_embed)
  154. setattr(self, f"block{i + 1}", block)
  155. setattr(self, f"norm{i + 1}", norm)
  156. # classification head
  157. self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
  158. self.apply(self._init_weights)
  159. def _init_weights(self, m):
  160. if isinstance(m, nn.Linear):
  161. trunc_normal_(m.weight, std=.02)
  162. if isinstance(m, nn.Linear) and m.bias is not None:
  163. nn.init.constant_(m.bias, 0)
  164. elif isinstance(m, nn.LayerNorm):
  165. nn.init.constant_(m.bias, 0)
  166. nn.init.constant_(m.weight, 1.0)
  167. elif isinstance(m, nn.Conv2d):
  168. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  169. fan_out //= m.groups
  170. m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
  171. if m.bias is not None:
  172. m.bias.data.zero_()
  173. def freeze_patch_emb(self):
  174. self.patch_embed1.requires_grad = False
  175. @torch.jit.ignore
  176. def no_weight_decay(self):
  177. return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
  178. def get_classifier(self):
  179. return self.head
  180. def reset_classifier(self, num_classes, global_pool=''):
  181. self.num_classes = num_classes
  182. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  183. def forward_features(self, x):
  184. B = x.shape[0]
  185. for i in range(self.num_stages):
  186. patch_embed = getattr(self, f"patch_embed{i + 1}")
  187. block = getattr(self, f"block{i + 1}")
  188. norm = getattr(self, f"norm{i + 1}")
  189. x, H, W = patch_embed(x)
  190. for blk in block:
  191. x = blk(x)
  192. x = x.flatten(2).transpose(1, 2)
  193. x = norm(x)
  194. if i != self.num_stages - 1:
  195. x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
  196. return x.mean(dim=1)
  197. def forward(self, x):
  198. x = self.forward_features(x)
  199. x = self.head(x)
  200. return x
  201. class DWConv(nn.Module):
  202. def __init__(self, dim=768):
  203. super(DWConv, self).__init__()
  204. self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
  205. def forward(self, x):
  206. x = self.dwconv(x)
  207. return x
  208. def _conv_filter(state_dict, patch_size=16):
  209. """ convert patch embedding weight from manual patchify + linear proj to conv"""
  210. out_dict = {}
  211. for k, v in state_dict.items():
  212. if 'patch_embed.proj.weight' in k:
  213. v = v.reshape((v.shape[0], 3, patch_size, patch_size))
  214. out_dict[k] = v
  215. return out_dict
  216. model_urls = {
  217. "van_b0": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar",
  218. "van_b1": "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar",
  219. "van_b2": "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar",
  220. "van_b3": "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar",
  221. }
  222. def load_model_weights(model, arch, kwargs):
  223. url = model_urls[arch]
  224. checkpoint = torch.hub.load_state_dict_from_url(
  225. url=url, map_location="cpu", check_hash=True
  226. )
  227. strict = True
  228. if "num_classes" in kwargs and kwargs["num_classes"] != 1000:
  229. strict = False
  230. del checkpoint["state_dict"]["head.weight"]
  231. del checkpoint["state_dict"]["head.bias"]
  232. model.load_state_dict(checkpoint["state_dict"], strict=strict)
  233. return model
  234. @register_model
  235. def van_b0(pretrained=False, **kwargs):
  236. model = VAN(
  237. embed_dims=[32, 64, 160, 256], mlp_ratios=[8, 8, 4, 4],
  238. norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 3, 5, 2],
  239. **kwargs)
  240. model.default_cfg = _cfg()
  241. if pretrained:
  242. model = load_model_weights(model, "van_b0", kwargs)
  243. return model
  244. @register_model
  245. def van_b1(pretrained=False, **kwargs):
  246. model = VAN(
  247. embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4],
  248. norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 4, 2],
  249. **kwargs)
  250. model.default_cfg = _cfg()
  251. if pretrained:
  252. model = load_model_weights(model, "van_b1", kwargs)
  253. return model
  254. @register_model
  255. def van_b2(pretrained=False, **kwargs):
  256. model = VAN(
  257. embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4],
  258. norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 3, 12, 3],
  259. **kwargs)
  260. model.default_cfg = _cfg()
  261. if pretrained:
  262. model = load_model_weights(model, "van_b2", kwargs)
  263. return model
  264. @register_model
  265. def van_b3(pretrained=False, **kwargs):
  266. model = VAN(
  267. embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4],
  268. norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 5, 27, 3],
  269. **kwargs)
  270. model.default_cfg = _cfg()
  271. if pretrained:
  272. model = load_model_weights(model, "van_b3", kwargs)
  273. return model
  274. @register_model
  275. def van_b4(pretrained=False, **kwargs):
  276. model = VAN(
  277. embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4],
  278. norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3],
  279. **kwargs)
  280. model.default_cfg = _cfg()
  281. if pretrained:
  282. model = load_model_weights(model, "van_b4", kwargs)
  283. return model
  284. @register_model
  285. def van_b5(pretrained=False, **kwargs):
  286. model = VAN(
  287. embed_dims=[96, 192, 480, 768], mlp_ratios=[8, 8, 4, 4],
  288. norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 3, 24, 3],
  289. **kwargs)
  290. model.default_cfg = _cfg()
  291. if pretrained:
  292. model = load_model_weights(model, "van_b5", kwargs)
  293. return model
  294. @register_model
  295. def van_b6(pretrained=False, **kwargs):
  296. model = VAN(
  297. embed_dims=[96, 192, 384, 768], mlp_ratios=[8, 8, 4, 4],
  298. norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[6,6,90,6],
  299. **kwargs)
  300. model.default_cfg = _cfg()
  301. if pretrained:
  302. model = load_model_weights(model, "van_b6", kwargs)
  303. return model
  304. if __name__=='__main__':
  305. model=van_b0()
  306. print(model)

讨论

最近,基于transformer的模型迅速征服了各种视觉排行榜。众所周知,自我注意只是一种特殊的注意机制。但是,人们逐渐默认采用自我注意,而忽略了潜在的注意方法。本文提出了一种新颖的注意力模块LKA和基于CNN的网络VAN。它超越了最先进基于transformer的视觉任务方法。我们希望本文能促进人们重新思考自我注意是否是不可替代的,以及哪种注意更适合视觉任务。

未来的工作

结构本身的不断改进。

在本文中,我们仅演示了一个直观的结构。有很多潜在的改进,例如采用不同的内核大小,引入多尺度结构 [11] 和使用多分支结构 [10]。

大规模自我监督学习和迁移学习。

VAN自然地结合了CNNs和vit的优点。一方面,VAN可以利用图像的2D结构信息。另一方面,VAN可以根据输入图像动态调整输出,这适合自我监督学习和迁移学习 [59],[64]。结合以上两点,我们相信VAN可以在图像自我监督学习和迁移学习领域取得更好的表现。

更多应用领域。

由于资源有限,我们仅在视觉任务中表现出出色的性能。VANs能否在NLP中像TCN [122] 这样的其他领域表现出色,仍然值得探讨。我们期待看到VANs成为通用模型。


本文转载自: https://blog.csdn.net/qq_40107571/article/details/127257396
版权归原作者 南妮儿 所有, 如有侵权,请联系我们删除。

“VAN(大核注意力机制)”的评论:

还没有评论