0


Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理

目录

1 数据集Dataset

Dataset

类是Pytorch中图像数据集操作的核心类,Pytorch中所有数据集加载类都继承自

Dataset

父类。当我们自定义数据集处理时,必须实现

Dataset

类中的三个接口:

  • 初始化def__init__(self) 构造函数,定义一些数据集的公有属性,如数据集下载地址、名称等
  • 数据集大小def__len__(self) 返回数据集大小,不同的数据集有不同的衡量数据量的方式
  • 数据集索引def__getitem__(self, index): 支持数据集索引功能,以实现形如dataset[i]得到数据集中的第i + 1个数据的功能。__getitem__是后期迭代数据时执行的具体函数,其返回值决定了循环变量,例如classdata(Dataset)...def__getitem__(self, idx:int):if self.transforms: img = self.transforms(img)return img, label # 返回的值即为后续迭代的循环变量for images, labels in dataLoader:...

2 数据加载DataLoader

为什么有了数据集

Dataset

还需要数据加载器

DataLoader

呢?原因在于神经网络需要进一步借助

DataLoader

对数据进行划分,也就是我们常说的

batch

,此外

DataLoader

还实现了打乱数据集、多线程等操作。

DataLoader

本质是一个可迭代对象,可以使用形如

for inputs, labels in dataloaders

进行可迭代对象的访问。

我们一般不需要去实现

DataLoader

的接口,只需要在构造函数中指定相应的参数即可,比如常见的

batch_size

shuffle

等参数。

下面这张图非常好地说明了

Dataset

DataLoader

的关系

在这里插入图片描述

接下来总结数据构造的三步法

  1. 继承Dataset对象,并实现__len__()__getitem__()魔法方法,该步骤的主要目的在于将文件形式的数据集处理为模型可用的标准数据格式,并加载到内存中;
  2. DataLoader对象封装Dataset,使其成为可迭代对象;
  3. 遍历DataLoader对象以将数据加载到模型中进行训练。

3 常用预处理方法

在数据集

Dataset

__getitem__()

中利用

torchvision.transforms

进行数据预处理与变换

常见的数据预处理变换方法总结如下表
序号变换含义1

RandomCrop(size, ...)

对输入图像依据给定size随机裁剪2

CenterCrop(size, ...)

对输入图像依据给定size从中心裁剪3

RandomResizedCrop(size, ...)

对输入图像随机长宽比裁剪,再放缩到给定size4

FiveCrop(size, ...)

对输入图像进行上下左右及中心裁剪,返回五张图像(size)组成的四维张量5

TenCrop(size, vertical_flip=False)

对输入图像进行上下左右及中心裁剪,再全部翻转(水平或垂直),返回十张图像(size)组成的四维张量6

RandomHorizontalFlip(p=0.5)

对输入图像按概率p随机进行水平翻转7

RandomVerticalFlip(p=0.5)

对输入图像按概率p随机进行垂直翻转8

RandomRotation(degree, ...)

对输入图像在degree内随机旋转某角度9

Resize(size, ...)

对输入图像重置分辨率10

Normalize(mean, std)

对输入图像各通道进行标准化11

ToTensor()

将输入图像或ndarray 转换为tensor并归一化12

Pad(padding, fill=0, padding_mode=‘constant’)

对输入图像进行填充13

ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

对输入图像修改亮度、对比度、饱和度、色度等14

Grayscale(num_output_channels=1)

对输入图像转灰度15

LinearTransformation(matrix)

对输入图像进行线性变换16

RandomAffine(...)

对输入图像进行仿射变换17

RandomGrayscale(p=0.1)

对输入图像按概率p随机转灰度18

ToPILImage(mode=None)

对输入图像转PIL格式图像19

RandomOrder()

随机打乱transforms操作顺序

4 模型处理

考虑以下场景:

网络的部分层级结构已经收敛、无需调整;大型复杂网络需要**微调(Fine-tune)**某些结构或参数;希望基于已训练好的模型进行改善或其他研究工作。

这些场景下重新通过数据集训练整个神经网络并无必要,甚至会使模型不稳定,因此引入**预训练(pretrained)**。Pytorch允许用户保存已训练好的模型,或加载其他模型,避免往复的无谓重训练,其中模型参数文件以

.pth

为后缀

# 保存已训练模型
torch.save(model.state_dict(), path)# 加载预训练模型
model.load_state_dict(torch.load(path), device)

通过设置模型某些层可学习参数的

requires_grad

属性为

False

即可固定这部分参数不被后续学习过程影响。深度学习框架应用优势之一在于预设了对GPU的支持,大大提高模型处理与训练的效率。Pytorch中通过

mode.to(device)

方法将模型部署到指定设备上(CPU/GPU),范式如下:

device = torch.device("cuda:0"if torch.cuda.is_available()else"cpu")
model.to(device)

工程上也常使用

torch.nn.DataParallel(model, devices)

来处理多GPU并行运算,其原理是:首先将模型加载到主GPU上,再将模型从主GPU产生若干副本到其余GPU,随后将一个batch中的数据按维度划分为不同的子任务给各GPU进行前向传播,得到的损失会被累积到主GPU上并由主GPU反向传播更新参数,最后将更新参数拷贝到其余GPU以开始下一轮训练。

5 实例:MNIST数据集处理

下面给出了处理MNIST手写数据集的完整代码,可以用于加深对数据处理流程的理解

from abc import abstractmethod
import numpy as np
from torchvision.datasets import mnist
from torch.utils.data import Dataset
from PIL import Image

classmnistData(Dataset):'''
    * @breif: MNIST数据集抽象接口
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''def__init__(self, dataPath:str, transforms=None)->None:super().__init__()
        self.dataPath = dataPath
        self.transforms = transforms
        self.data, self.label =[],[]def__len__(self)->int:returnlen(self.label)def__getitem__(self, idx:int):
        img = self.data[idx]if self.transforms:
            img = self.transforms(img)return img, self.label[idx]@abstractmethoddefplot(self, index:int)->None:pass@abstractmethoddefload(self)->list:passdefplotData(self, index:int, info:str=None)->None:'''
        * @breif: 可视化训练数据
        * @param[in]: index -> 数据集索引
        * @param[in]: info -> 备注信息
        * @retval: None
        '''print(info," --index:", index,"--label:", self.label[index])if info else \
        print(" --index:", index,"--label:", self.label[index])          
        img = Image.fromarray(np.uint8(self.data[index]))
        img.show()defloadData(self, train:bool)->list:'''
        * @breif: 下载与加载数据集
        * @param[in]: train -> 是否为训练集
        * @retval: 数据与标签列表
        '''# 如果指定目录下不存在数据集则下载
        dataSet   = mnist.MNIST(self.dataPath, train=train, download=True)# 初始化数据与标签
        data  =[ i[0]for i in dataSet ]
        label =[ i[1]for i in dataSet ]return data, label

classmnistTrainData(mnistData):'''
    * @breif: MNIST训练集
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''def__init__(self, dataPath:str, transforms=None)->None:super().__init__(dataPath, transforms=transforms)
        self.data, self.label = self.load()defplot(self, index:int)->None:
        self.plotData(index,"trainSet data")defload(self)->list:return self.loadData(train=True)classmnistTestData(mnistData):'''
    * @breif: MNIST测试集
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''def__init__(self, dataPath:str, transforms=None)->None:super().__init__(dataPath, transforms=transforms)
        self.data, self.label = self.load()defplot(self, index:int)->None:
        self.plotData(index,"testSet data")defload(self)->list:return self.loadData(train=False)

在这里插入图片描述


本文转载自: https://blog.csdn.net/FRIGIDWINTER/article/details/129644566
版权归原作者 Mr.Winter` 所有, 如有侵权,请联系我们删除。

“Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理”的评论:

还没有评论