目录
前言
文章性质:学习笔记 📖
视频教程:使用 Pytorch 搭建 U-Net 网络并基于 DRIVE 数据集训练(语义分割)- 1 网络的搭建
主要内容:根据 视频教程 中提供的 U-Net 源代码(PyTorch),对 DRIVE 文件夹结构和 predict.py、unet.py 文件进行具体讲解。
Preparation
├── src: 搭建U-Net模型代码
├── train_utils: 训练、验证以及多GPU训练相关模块
├── my_dataset.py: 自定义dataset用于读取DRIVE数据集(视网膜血管分割)
├── train.py: 以单GPU为例进行训练
├── train_multi_GPU.py: 针对使用多GPU的用户使用
├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
└── compute_mean_std.py: 统计数据集各通道的均值和标准差
一、U-Net 网络结构图
原论文提供的 U-Net 网络结构图如下所示:
原论文中提供的 U-Net 网络结构所使用的卷积层会改变特征层的高和宽,而现在比较主流的方式是 不去改变输入特征层的高和宽 ,将转置卷积替换成简单的双线性插值进行上采样,所以霹雳吧啦重绘的 U-Net 网络结构图也是按照 双线性插值 进行绘制的,如下图所示:
二、U-Net 网络源代码
1、DRIVE 数据集
将 DRIVE 数据集下载下来后放在 unet 项目目录下,因为在 train.py 文件中将读取数据集的目录默认设置为当前目录,如下图所示:
我们再来简单看看 DRIVE 文件夹的结构,它主要分为 test 测试(验证)集和 training 训练集。
在 training 训练集文件夹中:
- 1st_manual 提供了人工分割好的标签图片
- images 提供了用于分割的原图片
- mask 提供了二值图片,白色部分是要分割的感兴趣的区域,mask 提供了类似于蒙版的效果
在 test 测试(验证)集文件夹中:
- 1st_manual 提供了人工分割好的标签图片,用于精标准
- 2nd_manual 提供了人工分割好的标签图片,用于分割做验证
- images 提供了用于分割的原图片
- mask 提供了二值图片,白色部分是要分割的感兴趣的区域,mask 提供了类似于蒙版的效果
【彩蛋】为了方便王子公主们下载 DRIVE 数据集,微臣将官网的下载地址贴在这里:Introduction - Grand Challenge
【补充】当然也可以前往【Preparation】中提供的 GitHub 地址,在霹雳吧啦提供的网盘链接中下载。
2、U-Net 源代码的不同
(1)train.py
这里的 train.py 训练脚本和之前讲过的 FCN 源代码中的训练脚本类似,不同之处在于 create_model 创建模型部分:只需简单调用 U-Net ,传入相应的参数后创建即可,不需要载入预训练权重。
(2)train_and_val.py
在模型的训练和验证过程中,我们在 criterion 函数中引入了 dice_loss ,在 evaluate 函数中增加了 dice 指标,这个后面会进行详细的讲解。
(3)results.txt
训练完成后会生成 results 文本文件,保存了每轮训练的训练损失 train_loss、学习率 lr、Dice 系数 dice coefficient、全局正确率 global correct、平均行正确率 average row correct、交并比 IoU 和平均交并比 mean IoU 值。
3、predict.py 模型预测
在 predict.py 文件的 main 函数中,设置类别的数量为 1 ,不包含背景类别。接着依次设置了训练好的模型权重文件路径 weights_path、输入图像路径 img_path、感兴趣区域的掩模路径 roi_mask_path,并对这些路径文件进行断言,检查是否存在,不存在则输出对应的错误信息。
运行 predict.py 文件完成网络预测后,将生成 test_result.png 图片:
4、unet.py 模型搭建
这个 unet.py 是 U-Net 网络搭建部分。
(1)DoubleConv 类
因为在 UNet 网络结构中,卷积层基本是成对出现的,因此构造了 DoubleConv 类,in_channels 是输入特征层的 channel ,out_channels 是通过 DoubleConv 后的输出特征层的 channel ,mid_channels 是通过第一个卷积层后的输出特征图的 channel 。
【说明】因为这里的卷积层采用的是比较主流的方式,即不去改变特征图的高和宽,因此 padding 设置为 1 。
(2)Down 类
Down 类继承自 nn.Sequential 父类,该模块由一个下采样和两个卷积构成(Encoder):
(3)Up 类
Up 类继承自 nn.Module 父类,该模块由一个上采样和两个卷积构成(Decoder):
【说明1】在 init 初始化函数中,传入 bilinear 表示是否使用双线性插值替代转置卷积,Up 类默认会使用双线性插值,故令 bilinear=True ,这里的 in_channels 对应 concat 拼接之后的 channels 或者说对应 Up 模块中第一个卷积的输入 channels 。
【说明2】在 forward 函数中, 绿 框部分的处理是为了确保要 concat 拼接的** x1 和 x2 **的宽高相同,因为当最初输入的特征图的宽高不是 16 的整数倍时,在下采样后需要进行取整,再进行上采样后可能会出现 尺寸对不上 的问题。
(4)OutConv 类
OutConv 类继承自 nn.Module 父类,对应 U-Net 网络结构的最后一个卷积层,其卷积核个数为包含背景的分类类别个数 num_classes 。
(5)UNet 类 ★
UNet 类继承自 nn.Module 父类,传入参数包括:
- in_channels 是输入特征图的通道数,彩色图片为 3 ,灰度图片为 1
- num_classes 是包含背景的分类类别个数
- bilinear 表示是否使用双线性插值法
- base_c 是基础通道数 channel
【说明】factor 是一个因子,用于控制上采样过程中特征图的通道数,这个因子的值取决于是否使用双线性插值的 bilinear 标志:
- 当 bilinear=True 时,factor = 2,说明在上采样过程中,特征图的通道数减半,当 base_c 为 64 时,上采样后特征图的通道数为 32
- 当 bilinear=False 时,factor = 1,说明在上采样过程中,特征图的通道数不变,当 base_c 为 64 时,上采样后特征图的通道数为 64
这个 factor 因子的引入主要是为了在上采用过程中控制特征图的大小和复杂度,以适应不同的任务需求和计算资源限制。
这是 U-Net 的 forward 前向传播函数,它接受一个 torch.Tensor 类型的输入 x ,并返回一个字典类型的输出。
5、unet.py 的源代码
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Sequential):
def __init__(self, in_channels, out_channels, mid_channels=None):
if mid_channels is None:
mid_channels = out_channels
super(DoubleConv, self).__init__(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
class Down(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(Down, self).__init__(
nn.MaxPool2d(2, stride=2),
DoubleConv(in_channels, out_channels)
)
class Up(nn.Module):
def __init__(self, in_channels, out_channels, bilinear=True):
super(Up, self).__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
# [N, C, H, W]
diff_y = x2.size()[2] - x1.size()[2]
diff_x = x2.size()[3] - x1.size()[3]
# padding_left, padding_right, padding_top, padding_bottom
x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
diff_y // 2, diff_y - diff_y // 2])
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class OutConv(nn.Sequential):
def __init__(self, in_channels, num_classes):
super(OutConv, self).__init__(
nn.Conv2d(in_channels, num_classes, kernel_size=1)
)
class UNet(nn.Module):
def __init__(self,
in_channels: int = 1,
num_classes: int = 2,
bilinear: bool = True,
base_c: int = 64):
super(UNet, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.bilinear = bilinear
self.in_conv = DoubleConv(in_channels, base_c)
self.down1 = Down(base_c, base_c * 2)
self.down2 = Down(base_c * 2, base_c * 4)
self.down3 = Down(base_c * 4, base_c * 8)
factor = 2 if bilinear else 1
self.down4 = Down(base_c * 8, base_c * 16 // factor)
self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
self.up4 = Up(base_c * 2, base_c, bilinear)
self.out_conv = OutConv(base_c, num_classes)
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
x1 = self.in_conv(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.out_conv(x)
return {"out": logits}
版权归原作者 作者正在煮茶 所有, 如有侵权,请联系我们删除。