

How to Add Custom Attention Mechanisms to the YOLOv8 Network

In object detection tasks, adding an attention mechanism can improve a model's detection performance. This post shows how to integrate several attention mechanisms into the YOLOv8 model to strengthen its ability to extract image features:

  1. SimAM

  2. ShuffleAttention

  3. TripletAttention

  4. MHSA

  5. CBAM

  6. EMA

  7. ECA

For each mechanism we provide example code and discuss how to add the module to the YOLOv8 network.




1. Attention Mechanism Example Code

The following presents seven commonly used attention modules, each with example code. Every mechanism has its own strengths, so choose the one that best matches your task.

1.1 SimAM Module Code

SimAM (Simple Attention Module) is a lightweight attention mechanism. It achieves the effect of attention with very simple operations (and introduces no extra learnable parameters), which makes it well suited to projects that are sensitive to computational resources. It fits tasks that want to improve model performance while keeping the extra overhead minimal, such as real-time object detection on embedded devices.

import torch
import torch.nn as nn


class SimAM(torch.nn.Module):
    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()
        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activation(y)


# if __name__ == '__main__':
#     input = torch.randn(3, 64, 7, 7)
#     model = SimAM()
#     outputs = model(input)
#     print(outputs.shape)
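For reference (a reading aid based on the SimAM paper's derivation, not part of the original code): the y computed in forward is the reciprocal of SimAM's closed-form minimal energy per activation,

e_t^* = \frac{4(\hat{\sigma}^2 + \lambda)}{(t - \hat{\mu})^2 + 2\hat{\sigma}^2 + 2\lambda},
\qquad
y = \frac{1}{e_t^*} = \frac{(t - \hat{\mu})^2}{4(\hat{\sigma}^2 + \lambda)} + \frac{1}{2},

where \hat{\mu} and \hat{\sigma}^2 are the per-channel spatial mean and variance. Activations that stand out from their channel mean get a larger y, and therefore a larger weight after the sigmoid in x * self.activation(y).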
1.2 ShuffleAttention Module Code

ShuffleAttention suits scenarios that need global feature interaction. It rearranges features with a channel-shuffle operation so that information can flow between different channel groups, improving the global expressiveness of the features. For images with complex, diverse content (such as traffic scenes or cluttered natural environments), this mechanism can effectively improve the model's perception.

import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):
    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G, c//G, h, w
        # channel split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G, c//(2*G), h, w
        # channel attention
        x_channel = self.avg_pool(x_0)  # bs*G, c//(2*G), 1, 1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G, c//(2*G), 1, 1
        x_channel = x_0 * self.sigmoid(x_channel)
        # spatial attention
        x_spatial = self.gn(x_1)  # bs*G, c//(2*G), h, w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G, c//(2*G), h, w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G, c//(2*G), h, w
        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G, c//G, h, w
        out = out.contiguous().view(b, -1, h, w)
        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    se = ShuffleAttention(channel=512, G=8)
    output = se(input)
    print(output.shape)
1.3 TripletAttention Module Code

TripletAttention suits scenarios that need to capture features along multiple directions. It applies attention along three branches (pairwise interactions among the channel, height, and width dimensions), helping the model perceive directional structure in the feature map. It is especially useful for tasks that rely on orientation cues, such as road-sign detection and natural-scene understanding.

import torch
import torch.nn as nn


class BasicConv(nn.Module):
    # https://arxiv.org/pdf/2010.03045.pdf
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class ZPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class AttentionGate(nn.Module):
    def __init__(self):
        super(AttentionGate, self).__init__()
        kernel_size = 7
        self.compress = ZPool()
        self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.conv(x_compress)
        scale = torch.sigmoid_(x_out)
        return x * scale


class TripletAttention(nn.Module):
    def __init__(self, no_spatial=False):
        super(TripletAttention, self).__init__()
        self.cw = AttentionGate()
        self.hc = AttentionGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = AttentionGate()

    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.cw(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.hc(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.hw(x)
            x_out = 1 / 3 * (x_out + x_out11 + x_out21)
        else:
            x_out = 1 / 2 * (x_out11 + x_out21)
        return x_out
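A quick standalone shape check (added here for illustration, mirroring the self-tests of the other modules; not part of the original post):

if __name__ == '__main__':
    input = torch.randn(8, 512, 7, 7)    # N, C, H, W
    triplet = TripletAttention()
    output = triplet(input)
    print(output.shape)                  # expected: torch.Size([8, 512, 7, 7])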
1.4 MHSA Module Code

MHSA (Multi-Head Self-Attention) is the attention mechanism commonly used in Transformer models and suits large-scale context modeling. Through multi-head self-attention it helps the model capture long-range dependencies within an image. It performs well on tasks that rely on contextual information, such as multi-object detection in natural scenes, and is an ideal choice when global information matters and the objects in an image have complex relationships with one another.

import torch
import torch.nn as nn


class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
        super(MHSA, self).__init__()
        self.heads = heads
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.pos = pos_emb
        if self.pos:
            self.rel_h_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, 1, int(height)]),
                                             requires_grad=True)
            self.rel_w_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, int(width), 1]),
                                             requires_grad=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)  # 1, C, h*w, h*w
        c1, c2, c3, c4 = content_content.size()
        if self.pos:
            content_position = (self.rel_h_weight + self.rel_w_weight).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)  # 1, 4, 1024, 64
            content_position = torch.matmul(content_position, q)  # ([1, 4, 1024, 256])
            content_position = content_position if (
                content_content.shape == content_position.shape) else content_position[:, :, :c3, ]
            assert (content_content.shape == content_position.shape)
            energy = content_content + content_position
        else:
            energy = content_content
        attention = self.softmax(energy)
        out = torch.matmul(v, attention.permute(0, 1, 3, 2))  # 1, 4, 256, 64
        out = out.view(n_batch, C, width, height)
        return out


# if __name__ == '__main__':
#     input = torch.randn(50, 512, 7, 7)
#     mhsa = MHSA(n_dims=512)
#     output = mhsa(input)
#     print(output.shape)
1.5 CBAM Module Code

CBAM (Convolutional Block Attention Module) suits scenarios that need to combine channel and spatial features. By coupling channel attention with spatial attention, it helps the network capture the key regions of an image more precisely. It works for most object detection tasks, especially when the detection of particular objects needs to be refined, for example pedestrian or traffic-sign detection in autonomous driving.

import torch
from torch import nn


class ChannelAttention(nn.Module):
    # Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    # Spatial-attention module
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    # Convolutional Block Attention Module
    def __init__(self, c1, kernel_size=7):  # ch_in, kernels
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))
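A short sanity check for CBAM (illustrative addition, not from the original post):

if __name__ == '__main__':
    input = torch.randn(4, 256, 20, 20)
    cbam = CBAM(c1=256)        # c1 must match the incoming channel count
    output = cbam(input)
    print(output.shape)        # expected: torch.Size([4, 256, 20, 20])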
1.6 EMA Module Code

EMA (Efficient Multi-Scale Attention) suits scenarios that want strong attention at low computational cost. It splits the channels into groups and combines channel and spatial interactions, improving performance while reducing computational complexity, which also makes it practical for training on large datasets. It keeps the strong feature-capturing ability of attention while significantly lowering compute, making it a good fit for performance-critical workloads.

import torch
from torch import nn


class EMA(nn.Module):
    def __init__(self, channels, c2=None, factor=32):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g, c//g, h, w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)
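A minimal shape check for EMA (illustrative addition). Note that the channel count must be divisible by factor (the number of groups), otherwise the internal reshape and GroupNorm will fail:

if __name__ == '__main__':
    input = torch.randn(2, 512, 16, 16)
    ema = EMA(channels=512, factor=32)   # 512 is divisible by 32
    output = ema(input)
    print(output.shape)                  # expected: torch.Size([2, 512, 16, 16])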
1.7 ECA Module Code

ECA (Efficient Channel Attention) suits scenarios that need efficient channel attention. It removes the fully connected layers and instead uses a 1D convolution for local cross-channel interaction, which greatly reduces the parameter count while retaining the effect of channel attention. It is a good fit when computational resources are limited, such as object detection on mobile devices.

import torch
from torch import nn


class ECA(nn.Module):
    def __init__(self, channels: int, k_size: int = 3):
        super(ECA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Apply global average pooling
        y = self.avg_pool(x)
        # Reshape and apply 1D convolution across channels
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Apply sigmoid activation and element-wise multiplication
        return x * self.sigmoid(y)
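And a corresponding check for ECA (illustrative addition):

if __name__ == '__main__':
    input = torch.randn(4, 512, 7, 7)
    eca = ECA(channels=512)    # the 1D convolution works for any channel count; channels is kept for a consistent signature
    output = eca(input)
    print(output.shape)        # expected: torch.Size([4, 512, 7, 7])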

2. Steps to Add an Attention Mechanism

2.1 Modify the YOLOv8 Configuration File

We enable an attention mechanism by declaring it in the YOLOv8 model YAML. The example below inserts an attention layer at index 10, right after SPPF, using ShuffleAttention; the other mechanisms are left commented out, so uncomment whichever one you want to use. Note that inserting a layer at index 10 shifts every later layer index by one, which is why the head below concatenates with layers 13 and 10 and the Detect layer takes [16, 19, 22], rather than 12 and 9 and [15, 18, 21] as in the stock yolov8.yaml:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 81 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  # - [-1, 1, SimAM, [1024]] # 10
  - [-1, 1, ShuffleAttention, [16, 8]] # 10
  # - [-1, 1, TripletAttention, [1024]]
  # - [-1, 1, MHSA, [14, 14, 4]] # 10
  # - [-1, 1, CBAM, [1024]] # 10
  # - [-1, 1, EMA, [1024, 8]] # 10
  # - [-1, 1, ECA, [1024]] # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 13
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 19 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5) ShuffleAttention
2.2 Write the Custom Attention Modules

To integrate the custom attention modules into YOLOv8, place the corresponding code files (SimAM.py, ShuffleAttention.py, TripletAttention.py, MHSA.py, CBAM.py, EMA.py, ECA.py) under the project's ultralytics/nn/ directory, then import and register them in tasks.py as described below.

Step 1: Put the attention module files under ultralytics/nn/

First, make sure all of the attention module code files are placed in the ultralytics/nn/ directory:

  • SimAM.py
  • ShuffleAttention.py
  • TripletAttention.py
  • MHSA.py
  • CBAM.py
  • EMA.py
  • ECA.py

[Screenshot: my ultralytics/nn/ directory containing the attention module files]

Step 2: Import the attention modules in tasks.py

Next, import these modules at the top of ultralytics/nn/tasks.py by adding the following code:

from ultralytics.nn.MHSA import MHSA
from ultralytics.nn.ShuffleAttention import ShuffleAttention
from ultralytics.nn.SimAM import SimAM
from ultralytics.nn.CBAM import CBAM
from ultralytics.nn.TripletAttention import TripletAttention
from ultralytics.nn.EMA import EMA
from ultralytics.nn.ECA import ECA
Step 3: Add the construction logic for the attention modules between lines 910 and 950 of tasks.py

At roughly lines 910-950 of tasks.py (inside parse_model's if/elif chain that prepares each module's arguments), add the following logic so that these attention modules are initialized correctly when the network is built. You can use the following code as a reference:

        ## Add the attention mechanisms
        elif m in {MHSA, ShuffleAttention}:
            args = [ch[f], *args]
        elif m in {EMA}:
            args = [ch[f]]
        elif m in (SimAM, CBAM, TripletAttention, ECA):
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

[Screenshot: where I added this code inside tasks.py]

This code block makes sure each type of attention module is handled and initialized correctly when the network structure is built.
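As a worked illustration of the first branch (a sketch; the actual value of ch[f] depends on the model scale you train):

# For the backbone entry
#     - [-1, 1, ShuffleAttention, [16, 8]]   # 10
# the {MHSA, ShuffleAttention} branch rewrites args from [16, 8] to [ch[f], 16, 8],
# so parse_model ends up constructing
#     ShuffleAttention(channel=ch[f], reduction=16, G=8)
# where ch[f] is the channel count produced by the preceding SPPF layer (index 9).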

Recap

With the steps above, the custom attention mechanisms are integrated into YOLOv8. After importing the modules in tasks.py and adding the initialization logic, you can choose which attention mechanism to use by editing the model configuration file, and verify its effect during training.
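Before launching a full training run, a quick optional check (a sketch, assuming the modified config is saved as yolov8s_att.yaml as in Section 2.3 and the tasks.py changes above are in place) can confirm that the YAML parses and the attention layer sits where you expect:

from ultralytics import YOLO

model = YOLO("yolov8s_att.yaml")   # builds the network from the modified YAML; no weights needed
model.info()                       # prints the layer/parameter summary
print(model.model.model[10])       # layer 10 should print as the ShuffleAttention module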

2.3 Train and Verify the Attention Mechanism

Train with the modified configuration file (mine is yolov8_att.yaml). When running the training code I use yolov8s_att.yaml: the extra "s" after "yolov8" tells Ultralytics to build the s-scale network. Check the training log; if you see output like the highlighted lines below, ShuffleAttention has been loaded successfully, and the other attention mechanisms behave the same way:

from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO("yolov8s_att.yaml").load("yolov8s.pt")  # build from YAML and transfer weights
    # Train the model
    results = model.train(data=r"traffic_signage.yaml", epochs=150, batch=16, imgsz=1280)

[Screenshot: training log showing the ShuffleAttention layer in the model summary]


How do you apply these attention mechanisms in a project?

In real projects, the different attention mechanisms emphasize different things and suit different task scenarios. Choosing the right one for your project's specific needs can substantially improve the model:

  1. Lightweight model optimization: if your goal is real-time performance on embedded devices, choose a lightweight mechanism such as SimAM or ECA.
  2. Large-scale scene modeling: for detection tasks with complex backgrounds, MHSA and CBAM can effectively strengthen the model's ability to capture global features.
  3. Multi-object detection: for complex multi-object scenes (such as multi-object detection in autonomous driving), try ShuffleAttention or TripletAttention, which excel at capturing multi-directional and cross-channel interaction features.
  4. Compute-efficiency optimization: if you care about computational efficiency, EMA maintains strong performance while reducing compute overhead.

Choosing and applying these attention mechanisms sensibly can bring significant performance gains to object detection, classification, and semantic segmentation tasks.

3. Summary

This post detailed seven common attention mechanisms and showed how to integrate them into the YOLOv8 network. With an attention module added, the model can better capture the important information in an image, which improves object detection performance.

Tags: YOLO

Reposted from: https://blog.csdn.net/qq_39045712/article/details/142456611
Copyright belongs to the original author 一花·一树. In case of infringement, please contact us for removal.
