0


GCNet: Global Context Network(ICCV 2019)原理与代码解析

paper:GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond

official implementaion:https://github.com/xvjiarui/GCNet

Third party implementation:https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/context_block.py

存在的问题

通过捕获long-range dependency提取全局信息,对各种视觉任务都是很有帮助的。Non-local Network(介绍见https://blog.csdn.net/ooooocj/article/details/124573078)通过自注意力机制来解决这个问题。对于每个查询位置(query position),non-local network首先计算该位置和所有位置之间一个两两成对的关系,得到一个attention map。然后对attention map所有位置的权重加权求和得到汇总特征,每一个查询位置都得到一个汇总特征,将汇总特征与原始特征相加得到最终输出。

对于某个query position,non-local network计算的另一个位置与该位置的关系即一个权重值表示这个位置对query位置的重要程度。本文可视化attention map发现,不同的query位置其对应的attention map几乎一样,如下图所示

non-local block可以表示为下式

其中 (i) 是query position的索引,(j) 遍历所有位置,(f(\mathbf{x}{i},\mathbf{x}{j})) 表示位置 (i) 和 (j) 之间的关系,(\mathcal{C}(\mathbf{x})) 是归一化因子,(W_{z}) 和 (W_{v}) 是线性变换矩阵例如 (1\times1) 卷积。non-local block有多种不同的实例化方法,例如Gaussian、Embedded Gaussian、Dot product、Concat,下图(a)是Embedded Gaussian的结构。

由于需要计算每个query位置的attention map,因此non-local block的时间和空间复杂度都是所有位置的平方关系。

下图是作者从COCO数据集中随机挑选的6张图片,并可视化出3个不同的query position即图中的红点与对应的query-specific attention map,可以看出对于不同的查询位置,它们的attention map几乎是相同的。

为了进一步验证这一观察结果,作者又分析了不同的查询位置与全局上下文之间的距离。结果如下表所示。其中计算了三种向量之间的余弦距离cosine distance,分别为non-local block的输入、输出、以及查询的注意力图,对应表中的input、output、att。

从表中可以看出,input列的余弦距离比较大表明non-local的输入特征可以在不同的位置进行区分,但output列的余弦距离非常小,表明non-local block建模的全局上下文特征对于不同的query position几乎是相同的。attention map上的距离非常小,也验证了可视化的观察结果。

尽管non-local block是打算针对每个位置计算全局上下文的,但训练后的全局上下文实际上是独立于查询位置的。因为没有必要为每个查询位置单独计算query-specific全局上下文。

本文的创新点

本文通过观察发现non-local block针对每个query position计算的attention map最终结果是独立于查询位置的,那么就没有必要针对每个查询位置计算了,因此提出计算一个通用的attention map并应用于输入feature map上的所有位置,大大减少了计算量的同时又没有导致性能的降低。此外,结合SE block,设计了一个新的Global Context (GC) block,既轻量又可以有效地建模全局上下文。GC Block结合了Non-local block和SE block的优点,基于GC Block设计的GCNet在多个任务上均超过了NLNet和SENet。

方法介绍

作者舍去了式(1)中的 (W_{z}) 即图(3)(a)中的query分支,得到下式

这里采用的是最常用的Embedded Gaussiian的实例化方式。简化后的non-local block如图(3)(b)所示。

为了进一步减少计算量,作者应用分配率将 (W_{v}) 移到attention pool的外面,如下

这里简化后的non-local block如图4(b)所示。1x1卷积 (W_{v}) 的FLOPs从 (\mathcal{O}(HWC^{2})) 减小到 (\mathcal{O}(C^{2}))。

到目前为止简化的NL block中参数量最大的部分在transform module即图4(b)中的Transform部分,这里是一个1x1卷积但参数量为 (C\cdot C),当把Nl block应用到较深的层例如resnet中的 (res_{5}) 时,CxC=2028x2048,占据了整个block大部分的计算量。为了进一步减少计算量,作者借鉴了SE block的思想如图4(c),将图4(b)中的transform module换成了图4(d)中的bottleneck transform module,其中 (r) 是reduction ratio,这样参数量就从C·C变成了2·C·C/r,默认情况下r=16,因此参数量就减少为1/8。

实验结果

下表baseline是backbone为ResNet-50的Mask R-CNN在COCO数据集上的目标检测和实例分割的结果。将1个non-local block(NL)、1个simplified non-local block(SNL)、1个global context block(GC)插入到c4的最后一个residual block前,可以看出GC block获得的相似的性能但参数量更小。将GC block添加到所有的residual block中在参数量相似的情况下得到了更高的性能。

代码解析

这里的代码是mmcv中的实现。

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import Union
  3. import torch
  4. from torch import nn
  5. from ..utils import constant_init, kaiming_init
  6. from .registry import PLUGIN_LAYERS
  7. def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
  8. if isinstance(m, nn.Sequential):
  9. constant_init(m[-1], val=0)
  10. else:
  11. constant_init(m, val=0)
  12. @PLUGIN_LAYERS.register_module()
  13. class ContextBlock(nn.Module):
  14. """ContextBlock module in GCNet.
  15. See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
  16. (https://arxiv.org/abs/1904.11492) for details.
  17. Args:
  18. in_channels (int): Channels of the input feature map.
  19. ratio (float): Ratio of channels of transform bottleneck
  20. pooling_type (str): Pooling method for context modeling.
  21. Options are 'att' and 'avg', stand for attention pooling and
  22. average pooling respectively. Default: 'att'.
  23. fusion_types (Sequence[str]): Fusion method for feature fusion,
  24. Options are 'channels_add', 'channel_mul', stand for channelwise
  25. addition and multiplication respectively. Default: ('channel_add',)
  26. """
  27. _abbr_ = 'context_block'
  28. def __init__(self,
  29. in_channels: int,
  30. ratio: float,
  31. pooling_type: str = 'att',
  32. fusion_types: tuple = ('channel_add', )):
  33. super().__init__()
  34. assert pooling_type in ['avg', 'att']
  35. assert isinstance(fusion_types, (list, tuple))
  36. valid_fusion_types = ['channel_add', 'channel_mul']
  37. assert all([f in valid_fusion_types for f in fusion_types])
  38. assert len(fusion_types) > 0, 'at least one fusion should be used'
  39. self.in_channels = in_channels
  40. self.ratio = ratio
  41. self.planes = int(in_channels * ratio)
  42. self.pooling_type = pooling_type
  43. self.fusion_types = fusion_types
  44. if pooling_type == 'att':
  45. self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
  46. self.softmax = nn.Softmax(dim=2)
  47. else:
  48. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  49. if 'channel_add' in fusion_types:
  50. self.channel_add_conv = nn.Sequential(
  51. nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
  52. nn.LayerNorm([self.planes, 1, 1]),
  53. nn.ReLU(inplace=True), # yapf: disable
  54. nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
  55. else:
  56. self.channel_add_conv = None
  57. if 'channel_mul' in fusion_types:
  58. self.channel_mul_conv = nn.Sequential(
  59. nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
  60. nn.LayerNorm([self.planes, 1, 1]),
  61. nn.ReLU(inplace=True), # yapf: disable
  62. nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
  63. else:
  64. self.channel_mul_conv = None
  65. self.reset_parameters()
  66. def reset_parameters(self):
  67. if self.pooling_type == 'att':
  68. kaiming_init(self.conv_mask, mode='fan_in')
  69. self.conv_mask.inited = True
  70. if self.channel_add_conv is not None:
  71. last_zero_init(self.channel_add_conv)
  72. if self.channel_mul_conv is not None:
  73. last_zero_init(self.channel_mul_conv)
  74. def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
  75. batch, channel, height, width = x.size()
  76. if self.pooling_type == 'att':
  77. input_x = x
  78. # [N, C, H * W]
  79. input_x = input_x.view(batch, channel, height * width)
  80. # [N, 1, C, H * W]
  81. input_x = input_x.unsqueeze(1)
  82. # [N, 1, H, W]
  83. context_mask = self.conv_mask(x)
  84. # [N, 1, H * W]
  85. context_mask = context_mask.view(batch, 1, height * width)
  86. # [N, 1, H * W]
  87. context_mask = self.softmax(context_mask)
  88. # [N, 1, H * W, 1]
  89. context_mask = context_mask.unsqueeze(-1)
  90. # [N, 1, C, 1]
  91. context = torch.matmul(input_x, context_mask)
  92. # [N, C, 1, 1]
  93. context = context.view(batch, channel, 1, 1)
  94. else:
  95. # [N, C, 1, 1]
  96. context = self.avg_pool(x)
  97. return context
  98. def forward(self, x: torch.Tensor) -> torch.Tensor:
  99. # [N, C, 1, 1]
  100. context = self.spatial_pool(x)
  101. out = x
  102. if self.channel_mul_conv is not None:
  103. # [N, C, 1, 1]
  104. channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
  105. out = out * channel_mul_term
  106. if self.channel_add_conv is not None:
  107. # [N, C, 1, 1]
  108. channel_add_term = self.channel_add_conv(context)
  109. out = out + channel_add_term
  110. return out

本文转载自: https://blog.csdn.net/ooooocj/article/details/129655090
版权归原作者 00000cj 所有, 如有侵权,请联系我们删除。

“GCNet: Global Context Network(ICCV 2019)原理与代码解析”的评论:

还没有评论