

Semantic Segmentation Series 7: Attention U-Net (PyTorch Implementation)

Following the earlier articles on U-Net and U-Net++, this article introduces Attention U-Net.

Paper: "Attention U-Net: Learning Where to Look for the Pancreas".


Attention U-Net

Attention U-Net was published in 2018 for medical image segmentation; the paper's experiments focus on pancreas segmentation.

Core Idea of the Paper

The central contribution of Attention U-Net is the Attention Gate module. It uses soft attention in place of hard attention and integrates attention into U-Net's skip connections and upsampling path, realizing a spatial attention mechanism: irrelevant regions of the image are suppressed while salient local features are highlighted.

Network Architecture


Figure 1: The Attention U-Net model

The architecture of Attention U-Net closely resembles U-Net; the difference is the added Attention Gate modules, which apply attention between the skip connections and the upsampling layers (Figure 2).


Figure 2: The Attention Gate module

In the Attention Gate module, g is the gating signal coming from the next (deeper) layer and x^l is the skip-connection feature, as shown in Figure 3.


Figure 3: Inputs to the Attention Gate

Note that after computing W_g·g and W_x·x^l, the two results are added. At this point, however, the spatial dimensions of g and x^l do not match, since g comes from the deeper (coarser) layer; either x^l must be downsampled or g upsampled. (I prefer upsampling g: in the original U-Net the decoder already upsamples the deeper layer, so reusing that upsampled result saves computation, as the sketch below shows.)
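A minimal sketch of this size bookkeeping, using the deepest stage of the implementation below (shapes assume a 224×224 input; they are illustrative only):

import torch
import torch.nn as nn

g = torch.randn(1, 1024, 14, 14)    # deepest encoder output (gating source)
skip = torch.randn(1, 512, 28, 28)  # skip-connection feature one level up
# the decoder's transposed convolution doubles the spatial size,
# so the gating signal matches the skip feature before entering the gate
up = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)
print(up(g).shape)  # torch.Size([1, 512, 28, 28]) -- matches skip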

The sum of W_g·g and W_x·x^l passes through a ReLU, a 1×1×1 convolution (1×1 in the 2-D implementation below), and a sigmoid, producing an attention weight map; multiplying this weight map with the original skip input x^l yields the attended x^l. That is the idea of the Attention Gate.
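In the paper's notation the gate can be written compactly as follows (bias terms omitted here for brevity), where σ₁ is the ReLU and σ₂ the sigmoid:

$$
\alpha^l = \sigma_2\!\left(\psi^{\top}\,\sigma_1\!\left(W_g^{\top} g + W_x^{\top} x^l\right)\right),
\qquad
\hat{x}^l = \alpha^l \cdot x^l
$$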

Another important property of the Attention Gate is that its weights are learned by the network. Because soft attention is differentiable, gradients can flow through it, so the attention weights are learned through ordinary forward and backward passes, letting the network focus on the more important features.
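A quick differentiability check, using the AttentionBlock defined in the next section (the shapes here are illustrative):

import torch

gate = AttentionBlock(F_g=512, F_l=512, F_int=512)
g = torch.randn(2, 512, 28, 28)
x = torch.randn(2, 512, 28, 28, requires_grad=True)
out = gate(g, x)          # attended skip feature, same shape as x
out.sum().backward()      # gradients reach x and all gate parameters
print(x.grad.shape)       # torch.Size([2, 512, 28, 28])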

Model Implementation

Attention U-Net code

import torch
import torch.nn as nn

# Attention Gate: computes a spatial attention map from the gating signal g
# and the skip-connection feature x, then reweights x with it
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        # 1x1 convolutions project g and x into a common intermediate space
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(F_int)
        )
        # psi collapses the intermediate features to one channel of weights
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)  # attention weights in [0, 1]
        return x * psi       # reweight the original skip feature

# Attention U-Net: a U-Net whose skip connections pass through Attention Gates
class AttentionUnet(nn.Module):
    def __init__(self, num_classes):
        super(AttentionUnet, self).__init__()
        # ---------- encoder ----------
        self.stage_1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.stage_2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.stage_3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.stage_4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        self.stage_5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )
        # ---------- decoder: transposed convs double the spatial size ----------
        self.upsample_4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1)
        )
        self.upsample_3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
        )
        self.upsample_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        )
        self.upsample_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        )
        # conv stages applied after concatenating upsampled and attended features
        self.stage_up_4 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.stage_up_3 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.stage_up_2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.stage_up_1 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # Attention Gates on the three deepest skip connections
        self.Attentiongate1 = AttentionBlock(512, 512, 512)
        self.Attentiongate2 = AttentionBlock(256, 256, 256)
        self.Attentiongate3 = AttentionBlock(128, 128, 128)
        self.final = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, padding=1),
        )

    def forward(self, x):
        x = x.float()
        # encoder (downsampling) path
        stage_1 = self.stage_1(x)
        stage_2 = self.stage_2(stage_1)
        stage_3 = self.stage_3(stage_2)
        stage_4 = self.stage_4(stage_3)
        stage_5 = self.stage_5(stage_4)
        # decoder path: upsample, gate the skip feature, concatenate, convolve
        up_4 = self.upsample_4(stage_5)               # gating signal g
        stage_4 = self.Attentiongate1(up_4, stage_4)  # attended skip feature
        up_4_conv = self.stage_up_4(torch.cat([up_4, stage_4], dim=1))
        up_3 = self.upsample_3(up_4_conv)
        stage_3 = self.Attentiongate2(up_3, stage_3)
        up_3_conv = self.stage_up_3(torch.cat([up_3, stage_3], dim=1))
        up_2 = self.upsample_2(up_3_conv)
        stage_2 = self.Attentiongate3(up_2, stage_2)
        up_2_conv = self.stage_up_2(torch.cat([up_2, stage_2], dim=1))
        up_1 = self.upsample_1(up_2_conv)
        up_1_conv = self.stage_up_1(torch.cat([up_1, stage_1], dim=1))
        output = self.final(up_1_conv)
        return output
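A quick sanity check of the model (num_classes matches the CamVid setup below; with four 2× poolings, the input height and width should be divisible by 16):

model = AttentionUnet(num_classes=33)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 33, 224, 224])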

Dataset

As before, we use the CamVid dataset; see the earlier article on building and using the CamVid dataset.

# imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

torch.manual_seed(17)

# custom CamVid dataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Reads images, applies augmentation and preprocessing.

    Args:
        images_dir (str): path to the images folder
        masks_dir (str): path to the segmentation masks folder
    """
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ])
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

    def __getitem__(self, i):
        # read the image and its label; class indices live in the R channel of the mask
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array(Image.open(self.masks_fps[i]).convert('RGB'))
        transformed = self.transform(image=image, mask=mask)
        return transformed['image'], transformed['mask'][:, :, 0]

    def __len__(self):
        return len(self.ids)

# dataset paths -- adjust to your own layout
DATA_DIR = r'dataset\camvid'
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')

train_dataset = CamVidDataset(x_train_dir, y_train_dir)
val_dataset = CamVidDataset(x_valid_dir, y_valid_dir)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, drop_last=True)
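To confirm the pipeline, one can peek at a single batch (the masks are cast to long later, inside the training loop):

images, masks = next(iter(train_loader))
print(images.shape)  # torch.Size([32, 3, 224, 224])
print(masks.shape)   # torch.Size([32, 224, 224])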

Model Training

model = AttentionUnet(num_classes=33).cuda()
# model.load_state_dict(torch.load(r"checkpoints/Unet_100.pth"), strict=False)

from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd

# multi-class cross-entropy loss
lossf = nn.CrossEntropyLoss(ignore_index=255)
# SGD optimizer with a step learning-rate schedule
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1, last_epoch=-1)
# train for 50 epochs
epochs_num = 50

def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, scheduler,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])

    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []

    for epoch in range(num_epochs):
        # sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()
        print(f"epoch {epoch+1} --- loss {metric[0] / metric[2]:.3f} --- train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- cost time {timer.sum()}")

        # --------- save training statistics ---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch + 1)
        time_list.append(timer.sum())
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df['time'] = time_list
        df.to_excel("savefile/AttentionUnet_camvid1.xlsx")

        # ---------------- save the model -------------------
        if np.mod(epoch + 1, 5) == 0:
            torch.save(model.state_dict(), f'checkpoints/AttentionUnet_{epoch+1}.pth')

Start Training

train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, scheduler)
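Once training has finished, a minimal inference sketch might look like this (the checkpoint path assumes the save loop above completed all 50 epochs):

model = AttentionUnet(num_classes=33).cuda()
model.load_state_dict(torch.load('checkpoints/AttentionUnet_50.pth'))
model.eval()
with torch.no_grad():
    image, mask = val_dataset[0]
    logits = model(image.unsqueeze(0).cuda())     # (1, 33, 224, 224)
    pred = logits.argmax(dim=1).squeeze(0).cpu()  # (224, 224) class-index map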

Training Results


Reposted from: https://blog.csdn.net/yumaomi/article/details/124866235 (original author: yumaomi).
