0


语义分割系列11-DAnet(pytorch实现)

DAnet:Dual Attention Network for Scene Segmentation

发布于CVPR2019,本文将进行DAnet的论文讲解和复现工作。

论文部分

主要思想

DAnet的思想并没有之前提到的DFAnet那么花里胡哨,需要各种多层次的连接,DAnet的主要思想就是——同时引入了空间注意力和通道注意力,也就是Dual Attention = Channel Attention + Position Attention。

其中,Position Attention可以在位置上,捕捉任意两个位置之间的上下文信息,而Channel Attention可以捕捉通道维度上的上下文信息

关于Position Attention:较为通俗的解释是,所有的位置,两两之间都有一个权重γ,这个γ的值由两个位置之间的相似性来决定,而不是由两个位置的距离来决定,这就提供了一个好处,也就是——无论两个位置距离多远,只要他们相似度高,空间注意力机制就可以锁定这两个位置。

关于Channel Attention:在高级语义特征中,每一个通道都可以被认为是对于某一个类的特殊响应,增强拥有这种响应的特征通道可以有效的提高分割效果。而通道注意力在EncNet和DFAnet中都有应用,通过计算一个权重因子,对每个通道进行加权,突出重要的通道,增强特征表示。

作者的一些观点

  1. 关于为什么需要Attention机制,作者认为,在卷积的过程中,导致感受野局限在某一范围,而这种操作导致相同类别的像素之间产生一定的差异,这会导致识别上准确率降低的问题。
  2. 与大部分作者相同,在文中作者也对ResNet的最后几层做了一些改动,加入空洞卷积,将原先ResNet下采样速率从32倍降低到8倍,也就是ResNet最后一层输出的特征图大小为原始输入的1/8。这样子做的好处就是保留了更多的细节信息,毕竟下采样过多倍速以后细节容易丢失。

模型部分

DAnet主要的部分是通道注意力和空间注意力的实现,模型如图1。


图1 DAnet

Position attention module


图2 Position attention

对于空间注意力的实现,首先将特征图A(C×H×W)输入到卷积模块中,生成B(C×H×W)和C(C×H×W),将B和C reshape成(C×N)维度,其中N=H×W,N就是像素点的个数。随后,将B矩阵转置后和C矩阵相乘,将结果输入到softmax中,得到一个空间注意力图S(N×N)。矩阵的乘法相当于让每一个像素点之间都产生了联系,也就是上文提到的两个位置之间的相似度γ。其中,两个位置相似度越高,Sji这个值就越大。

同样,A输入到另一个卷积层生成新的特征映射D(C×H×W),reshape成C×N后与上述的空间注意力图S的转置进行相乘,这样就得到了C×N大小的矩阵,再将这个矩阵reshape成原来的C×H×W大小。将这个矩阵乘以一个系数α(与前文提到的α不是同一个值),然后加上原始的特征图A。这样就实现了一个空间注意力机制(Position Attention)。需要注意的是,这个α值是可学习参数,初始化为0。

Channel Attention

Channel Attention机制的实现与Position Attention类似,但与DFAnet和EncNet中使用fc attention来实现Channel Attention的方式略微不同。


图3 Channel Attention

同样,特征图A(C×H×W)reshape成C×N的矩阵,分别经过转置、矩阵乘法、softmax到注意力图X(C×C)。

随后这个注意力图X与reshape成C×N的A矩阵进行矩阵乘法,得到的输出(C×N)再reshape成C×H×W和原始特征图A进行加权。

这里的β是一个可学习参数,初始化为0。

需要注意的是计算通道注意力时没有通过任何卷积层来嵌入特征,与Position attention实现上有一定差异。作者的解释是,这样可以保留原始通道之间的关系。

结果部分

作者在Cityscapes、PASCAL VOC等数据集上都做了一些测试,来证明DAnet的优越性。同时呢,作者也对Attention的两个模块做了一个可视化,见图4。


图4 attention可视化

可以看到,Attention模块确实能够看到一些重要的信息,比如车、树等。效果也确实很好。

模型复现

DAnet网络

主干网络resnet50

  1. from torchvision.models import resnet50, resnet101
  2. from torchvision.models._utils import IntermediateLayerGetter
  3. import torch
  4. import torch.nn as nn
  5. backbone=IntermediateLayerGetter(
  6. resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True]),
  7. return_layers={'layer4': 'stage4'}
  8. )
  9. # test
  10. x = torch.randn(3, 3, 224, 224).cpu()
  11. result = backbone(x)
  12. for k, v in result.items():
  13. print(k, v.shape)

DAHead

  1. class PositionAttention(nn.Module):
  2. def __init__(self, in_channels):
  3. super(PositionAttention, self).__init__()
  4. self.convB = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
  5. self.convC = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
  6. self.convD = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
  7. #创建一个可学习参数a作为权重,并初始化为0.
  8. self.gamma = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
  9. self.gamma.data.fill_(0.)
  10. self.softmax = nn.Softmax(dim=2)
  11. def forward(self, x):
  12. b,c,h,w = x.size()
  13. B = self.convB(x)
  14. C = self.convB(x)
  15. D = self.convB(x)
  16. S = self.softmax(torch.matmul(B.view(b, c, h*w).transpose(1, 2), C.view(b, c, h*w)))
  17. E = torch.matmul(D.view(b, c, h*w), S.transpose(1, 2)).view(b,c,h,w)
  18. #gamma is a parameter which can be training and iter
  19. E = self.gamma * E + x
  20. return E
  21. class ChannelAttention(nn.Module):
  22. def __init__(self):
  23. super(ChannelAttention, self).__init__()
  24. self.beta = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
  25. self.beta.data.fill_(0.)
  26. self.softmax = nn.Softmax(dim=2)
  27. def forward(self, x):
  28. b,c,h,w = x.size()
  29. X = self.softmax(torch.matmul(x.view(b, c, h*w), x.view(b, c, h*w).transpose(1, 2)))
  30. X = torch.matmul(X.transpose(1, 2), x.view(b, c, h*w)).view(b, c, h, w)
  31. X = self.beta * X + x
  32. return X
  33. class DAHead(nn.Module):
  34. def __init__(self, in_channels, num_classes):
  35. super(DAHead, self).__init__()
  36. self.conv1 = nn.Sequential(
  37. nn.Conv2d(in_channels, in_channels//4, kernel_size=3, padding=1, bias=False),
  38. nn.BatchNorm2d(in_channels//4),
  39. nn.ReLU(),
  40. )
  41. self.conv2 = nn.Sequential(
  42. nn.Conv2d(in_channels, in_channels//4, kernel_size=3, padding=1, bias=False),
  43. nn.BatchNorm2d(in_channels//4),
  44. nn.ReLU(),
  45. )
  46. self.conv3 = nn.Sequential(
  47. nn.Conv2d(in_channels//4, in_channels//4, kernel_size=3, padding=1, bias=False),
  48. nn.BatchNorm2d(in_channels//4),
  49. nn.ReLU(),
  50. )
  51. self.conv4 = nn.Sequential(
  52. nn.Conv2d(in_channels//4, in_channels//8, kernel_size=3, padding=1, bias=False),
  53. nn.BatchNorm2d(in_channels//8),
  54. nn.ReLU(),
  55. nn.Conv2d(in_channels//8, num_classes, kernel_size=3, padding=1, bias=False),
  56. )
  57. self.PositionAttention = PositionAttention(in_channels//4)
  58. self.ChannelAttention = ChannelAttention()
  59. def forward(self, x):
  60. x_PA = self.conv1(x)
  61. x_CA = self.conv2(x)
  62. PosionAttentionMap = self.PositionAttention(x_PA)
  63. ChannelAttentionMap = self.ChannelAttention(x_CA)
  64. #这里可以额外分别做PAM和CAM的卷积输出,分别对两个分支做一个上采样和预测,
  65. #可以生成一个cam loss和pam loss以及最终融合后的结果的loss.以及做一些可视化工作
  66. #这里只输出了最终的融合结果.与原文有一些出入.
  67. output = self.conv3(PosionAttentionMap + ChannelAttentionMap)
  68. output = nn.functional.interpolate(output, scale_factor=8, mode="bilinear",align_corners=True)
  69. output = self.conv4(output)
  70. return output

DAnet

  1. class DAnet(nn.Module):
  2. def __init__(self, num_classes):
  3. super(DAnet, self).__init__()
  4. self.ResNet50 = IntermediateLayerGetter(
  5. resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True]),
  6. return_layers={'layer4': 'stage4'}
  7. )
  8. self.decoder = DAHead(in_channels=2048, num_classes=num_classes)
  9. def forward(self, x):
  10. feats = self.ResNet50(x)
  11. # self.ResNet50返回的是一个字典类型的数据.
  12. x = self.decoder(feats["stage4"])
  13. return x
  14. if __name__ == "__main__":
  15. x = torch.randn(3, 3, 224, 224).cpu()
  16. model = DAnet(num_classes=3)
  17. result = model(x)
  18. print(result.shape)

数据集-Camvid

数据集的创建和使用见教程:CamVid数据集的创建和使用

  1. # 导入库
  2. import os
  3. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  4. import torch
  5. import torch.nn as nn
  6. import torch.optim as optim
  7. import torch.nn.functional as F
  8. from torch import optim
  9. from torch.utils.data import Dataset, DataLoader, random_split
  10. from tqdm import tqdm
  11. import warnings
  12. warnings.filterwarnings("ignore")
  13. import os.path as osp
  14. import matplotlib.pyplot as plt
  15. from PIL import Image
  16. import numpy as np
  17. import albumentations as A
  18. from albumentations.pytorch.transforms import ToTensorV2
  19. torch.manual_seed(17)
  20. # 自定义数据集CamVidDataset
  21. class CamVidDataset(torch.utils.data.Dataset):
  22. """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
  23. Args:
  24. images_dir (str): path to images folder
  25. masks_dir (str): path to segmentation masks folder
  26. class_values (list): values of classes to extract from segmentation mask
  27. augmentation (albumentations.Compose): data transfromation pipeline
  28. (e.g. flip, scale, etc.)
  29. preprocessing (albumentations.Compose): data preprocessing
  30. (e.g. noralization, shape manipulation, etc.)
  31. """
  32. def __init__(self, images_dir, masks_dir):
  33. self.transform = A.Compose([
  34. A.Resize(224, 224),
  35. A.HorizontalFlip(),
  36. A.VerticalFlip(),
  37. A.Normalize(),
  38. ToTensorV2(),
  39. ])
  40. self.ids = os.listdir(images_dir)
  41. self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
  42. self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
  43. def __getitem__(self, i):
  44. # read data
  45. image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
  46. mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
  47. image = self.transform(image=image,mask=mask)
  48. return image['image'], image['mask'][:,:,0]
  49. def __len__(self):
  50. return len(self.ids)
  51. # 设置数据集路径
  52. DATA_DIR = r'dataset\camvid' # 根据自己的路径来设置
  53. x_train_dir = os.path.join(DATA_DIR, 'train_images')
  54. y_train_dir = os.path.join(DATA_DIR, 'train_labels')
  55. x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
  56. y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
  57. train_dataset = CamVidDataset(
  58. x_train_dir,
  59. y_train_dir,
  60. )
  61. val_dataset = CamVidDataset(
  62. x_valid_dir,
  63. y_valid_dir,
  64. )
  65. train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,drop_last=True)
  66. val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True,drop_last=True)

模型训练

  1. model = DAnet(num_classes=33).cuda()
  2. #model.load_state_dict(torch.load(r"checkpoints/resnet101-5d3b4d8f.pth"),strict=False)
  1. from d2l import torch as d2l
  2. from tqdm import tqdm
  3. import pandas as pd
  4. #损失函数选用多分类交叉熵损失函数
  5. lossf = nn.CrossEntropyLoss(ignore_index=255)
  6. #选用adam优化器来训练
  7. optimizer = optim.SGD(model.parameters(),lr=0.1)
  8. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1, last_epoch=-1)
  9. #训练50轮
  10. epochs_num = 50
  11. def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,scheduler,
  12. devices=d2l.try_all_gpus()):
  13. timer, num_batches = d2l.Timer(), len(train_iter)
  14. animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
  15. legend=['train loss', 'train acc', 'test acc'])
  16. net = nn.DataParallel(net, device_ids=devices).to(devices[0])
  17. loss_list = []
  18. train_acc_list = []
  19. test_acc_list = []
  20. epochs_list = []
  21. time_list = []
  22. for epoch in range(num_epochs):
  23. # Sum of training loss, sum of training accuracy, no. of examples,
  24. # no. of predictions
  25. metric = d2l.Accumulator(4)
  26. for i, (features, labels) in enumerate(train_iter):
  27. timer.start()
  28. l, acc = d2l.train_batch_ch13(
  29. net, features, labels.long(), loss, trainer, devices)
  30. metric.add(l, acc, labels.shape[0], labels.numel())
  31. timer.stop()
  32. if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
  33. animator.add(epoch + (i + 1) / num_batches,
  34. (metric[0] / metric[2], metric[1] / metric[3],
  35. None))
  36. test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
  37. animator.add(epoch + 1, (None, None, test_acc))
  38. scheduler.step()
  39. print(f"epoch {epoch+1} --- loss {metric[0] / metric[2]:.3f} --- train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- cost time {timer.sum()}")
  40. #---------保存训练数据---------------
  41. df = pd.DataFrame()
  42. loss_list.append(metric[0] / metric[2])
  43. train_acc_list.append(metric[1] / metric[3])
  44. test_acc_list.append(test_acc)
  45. epochs_list.append(epoch+1)
  46. time_list.append(timer.sum())
  47. df['epoch'] = epochs_list
  48. df['loss'] = loss_list
  49. df['train_acc'] = train_acc_list
  50. df['test_acc'] = test_acc_list
  51. df['time'] = time_list
  52. df.to_excel("savefile/DAnet_camvid.xlsx")
  53. #----------------保存模型-------------------
  54. if np.mod(epoch+1, 5) == 0:
  55. torch.save(model.state_dict(), f'checkpoints/DAnet_{epoch+1}.pth')

开始训练

  1. train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num,scheduler)

训练结果


本文转载自: https://blog.csdn.net/yumaomi/article/details/124909672
版权归原作者 yumaomi 所有, 如有侵权,请联系我们删除。

“语义分割系列11-DAnet(pytorch实现)”的评论:

还没有评论