0


【霹雳吧啦】手把手带你入门语义分割の番外8:U-Net 源码讲解(PyTorch)—— 网络的搭建

目录

前言

文章性质:学习笔记 📖

视频教程:使用 Pytorch 搭建 U-Net 网络并基于 DRIVE 数据集训练(语义分割)- 1 网络的搭建

主要内容:根据 视频教程 中提供的 U-Net 源代码(PyTorch),对 DRIVE 文件夹结构和 predict.py、unet.py 文件进行具体讲解。

Preparation

源代码:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_segmentation/unet

├── src: 搭建U-Net模型代码
├── train_utils: 训练、验证以及多GPU训练相关模块
├── my_dataset.py: 自定义dataset用于读取DRIVE数据集(视网膜血管分割)
├── train.py: 以单GPU为例进行训练
├── train_multi_GPU.py: 针对使用多GPU的用户使用
├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
└── compute_mean_std.py: 统计数据集各通道的均值和标准差

一、U-Net 网络结构图

原论文提供的 U-Net 网络结构图如下所示:

原论文中提供的 U-Net 网络结构所使用的卷积层会改变特征层的高和宽,而现在比较主流的方式是 不去改变输入特征层的高和宽 ,将转置卷积替换成简单的双线性插值进行上采样,所以霹雳吧啦重绘的 U-Net 网络结构图也是按照 双线性插值 进行绘制的,如下图所示:

二、U-Net 网络源代码

1、DRIVE 数据集

将 DRIVE 数据集下载下来后放在 unet 项目目录下,因为在 train.py 文件中将读取数据集的目录默认设置为当前目录,如下图所示:

我们再来简单看看 DRIVE 文件夹的结构,它主要分为 test 测试(验证)集和 training 训练集。

在 training 训练集文件夹中:

  • 1st_manual 提供了人工分割好的标签图片
  • images 提供了用于分割的原图片
  • mask 提供了二值图片,白色部分是要分割的感兴趣的区域,mask 提供了类似于蒙版的效果

在 test 测试(验证)集文件夹中:

  • 1st_manual 提供了人工分割好的标签图片,用于精标准
  • 2nd_manual 提供了人工分割好的标签图片,用于分割做验证
  • images 提供了用于分割的原图片
  • mask 提供了二值图片,白色部分是要分割的感兴趣的区域,mask 提供了类似于蒙版的效果

【彩蛋】为了方便王子公主们下载 DRIVE 数据集,微臣将官网的下载地址贴在这里:Introduction - Grand Challenge

【补充】当然也可以前往【Preparation】中提供的 GitHub 地址,在霹雳吧啦提供的网盘链接中下载。

2、U-Net 源代码的不同

(1)train.py

这里的 train.py 训练脚本和之前讲过的 FCN 源代码中的训练脚本类似,不同之处在于 create_model 创建模型部分:只需简单调用 U-Net ,传入相应的参数后创建即可,不需要载入预训练权重。

(2)train_and_val.py

在模型的训练和验证过程中,我们在 criterion 函数中引入了 dice_loss ,在 evaluate 函数中增加了 dice 指标,这个后面会进行详细的讲解。

(3)results.txt

训练完成后会生成 results 文本文件,保存了每轮训练的训练损失 train_loss、学习率 lr、Dice 系数 dice coefficient、全局正确率 global correct、平均行正确率 average row correct、交并比 IoU 和平均交并比 mean IoU 值。

3、predict.py 模型预测

在 predict.py 文件的 main 函数中,设置类别的数量为 1 ,不包含背景类别。接着依次设置了训练好的模型权重文件路径 weights_path、输入图像路径 img_path、感兴趣区域的掩模路径 roi_mask_path,并对这些路径文件进行断言,检查是否存在,不存在则输出对应的错误信息。

运行 predict.py 文件完成网络预测后,将生成 test_result.png 图片:

4、unet.py 模型搭建

这个 unet.py 是 U-Net 网络搭建部分。

(1)DoubleConv 类

因为在 UNet 网络结构中,卷积层基本是成对出现的,因此构造了 DoubleConv 类,in_channels 是输入特征层的 channel ,out_channels 是通过 DoubleConv 后的输出特征层的 channel ,mid_channels 是通过第一个卷积层后的输出特征图的 channel 。

【说明】因为这里的卷积层采用的是比较主流的方式,即不去改变特征图的高和宽,因此 padding 设置为 1 。

(2)Down 类

Down 类继承自 nn.Sequential 父类,该模块由一个下采样和两个卷积构成(Encoder):

(3)Up 类

Up 类继承自 nn.Module 父类,该模块由一个上采样和两个卷积构成(Decoder):

【说明1】init 初始化函数中,传入 bilinear 表示是否使用双线性插值替代转置卷积,Up 类默认会使用双线性插值,故令 bilinear=True ,这里的 in_channels 对应 concat 拼接之后的 channels 或者说对应 Up 模块中第一个卷积的输入 channels 。

【说明2】在 forward 函数中, 绿 框部分的处理是为了确保要 concat 拼接的** x1 x2 **的宽高相同,因为当最初输入的特征图的宽高不是 16 的整数倍时,在下采样后需要进行取整,再进行上采样后可能会出现 尺寸对不上 的问题。

(4)OutConv 类

OutConv 类继承自 nn.Module 父类,对应 U-Net 网络结构的最后一个卷积层,其卷积核个数为包含背景的分类类别个数 num_classes 。

(5)UNet 类 ★

UNet 类继承自 nn.Module 父类,传入参数包括:

  • in_channels 是输入特征图的通道数,彩色图片为 3 ,灰度图片为 1
  • num_classes 是包含背景的分类类别个数
  • bilinear 表示是否使用双线性插值法
  • base_c 是基础通道数 channel

【说明】factor 是一个因子,用于控制上采样过程中特征图的通道数,这个因子的值取决于是否使用双线性插值的 bilinear 标志:

  • 当 bilinear=True 时,factor = 2,说明在上采样过程中,特征图的通道数减半,当 base_c 为 64 时,上采样后特征图的通道数为 32
  • 当 bilinear=False 时,factor = 1,说明在上采样过程中,特征图的通道数不变,当 base_c 为 64 时,上采样后特征图的通道数为 64

这个 factor 因子的引入主要是为了在上采用过程中控制特征图的大小和复杂度,以适应不同的任务需求和计算资源限制。

这是 U-Net 的 forward 前向传播函数,它接受一个 torch.Tensor 类型的输入 x ,并返回一个字典类型的输出。

5、unet.py 的源代码

from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__(
            nn.MaxPool2d(2, stride=2),
            DoubleConv(in_channels, out_channels)
        )

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # [N, C, H, W]
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]

        # padding_left, padding_right, padding_top, padding_bottom
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

class OutConv(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__(
            nn.Conv2d(in_channels, num_classes, kernel_size=1)
        )

class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,
                 bilinear: bool = True,
                 base_c: int = 64):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2)
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)
        self.out_conv = OutConv(base_c, num_classes)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x)

        return {"out": logits}

本文转载自: https://blog.csdn.net/nanzhou520/article/details/135232999
版权归原作者 作者正在煮茶 所有, 如有侵权,请联系我们删除。

“【霹雳吧啦】手把手带你入门语义分割の番外8:U-Net 源码讲解(PyTorch)—— 网络的搭建”的评论:

还没有评论