【AI入门超详细系列】卷积神经网络(CNN)入门指南【Pytorch版】
👋 默子的前言
大家好,我是默子!欢迎来到“默子AI”的世界。今天,我们将深入探索 PyTorch 的强大功能,学习如何使用卷积神经网络(CNN)识别图像数据。
无论你是深度学习的新手,还是希望强化实践经验的开发者,这篇教程都将为你提供详尽的指导和深入的解说。准备好了吗?
让我们一起开启这段充满干货与乐趣的学习之旅吧!
别忘了关注我的公众号“默子AI”,获取更多精彩内容!
🛠 环境配置
Step 1:安装 PyTorch
要开始使用 PyTorch,首先需要在你的开发环境中安装它。PyTorch 支持多种操作系统和硬件加速选项(如 CUDA 用于 GPU 加速)。以下是安装 PyTorch 的基本步骤:
安装步骤:
- 选择适合的安装命令:访问 PyTorch 官方网站 https://pytorch.org/get-started/locally/ 获取适合你系统的安装命令。选择适当的操作系统、包管理器、Python 版本和是否需要 CUDA 支持。
如果使用官网命令安装,那就不用再pip重新装一遍了,直接到验证安装即可。
- 使用
pip
安装:对于大多数用户,使用pip
是最简单的方法。例如,安装最新版本的 PyTorch 和 torchvision(用于图像处理):
pip install torch torchvision
如果你需要 GPU 支持(假设你有 NVIDIA GPU 并已安装 CUDA),你可以选择带有 CUDA 的版本:
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
提示:确保你的系统已安装相应版本的 CUDA 驱动。如果不确定是否需要 CUDA,可以选择 CPU 版本,尽管训练速度会较慢。
- 验证安装:安装完成后,可以通过以下 Python 代码验证 PyTorch 是否正确安装,并检查是否支持 GPU:
import torch
print('CUDA版本:',torch.version.cuda)print('Pytorch版本:',torch.__version__)print('显卡是否可用:','可用'if(torch.cuda.is_available())else'不可用')print('显卡数量:',torch.cuda.device_count())print('是否支持BF16数字格式:','支持'if(torch.cuda.is_bf16_supported())else'不支持')print('当前显卡型号:',torch.cuda.get_device_name())print('当前显卡的CUDA算力:',torch.cuda.get_device_capability())print('当前显卡的总显存:',torch.cuda.get_device_properties(0).total_memory/1024/1024/1024,'GB')print('是否支持TensorCore:','支持'if(torch.cuda.get_device_properties(0).major >=7)else'不支持')print('当前显卡的显存使用率:',torch.cuda.memory_allocated(0)/torch.cuda.get_device_properties(0).total_memory*100,'%')
常见问题及解决方案:
- 版本冲突:建议使用虚拟环境(如
venv
或conda
)来管理不同项目的依赖,避免包版本冲突。 - CUDA 安装问题:确保 CUDA 驱动与 PyTorch 安装版本兼容。参考 官方CUDA安装页面https://developer.nvidia.com/cuda-downloads 获取详细步骤。
注意!🚨
这里大家可能会遇到非常多的问题,比如如何查看自己本地的CUDA版本,如何查看自己本地的GPU版本,安装cudnn等等,这里就不一一赘述了,大家如果遇到问题可以在公众号后台留言或者是自行百度/谷歌。或者是问问AI大模型
📦 导入依赖库
Step 2:导入必要的库
在编写 PyTorch 代码之前,我们需要导入一些核心库。这些库将帮助我们构建、训练和测试我们的 CNN 模型。下面是需要导入的主要库及其用途:
import torch # PyTorch 核心库,提供张量操作和自动微分import torch.nn as nn # 构建神经网络的模块,包含各种层和损失函数import torch.optim as optim # 优化器模块,用于模型参数的更新import torchvision # 图像处理相关库,提供常用数据集和图像变换import torchvision.transforms as transforms # 图像预处理模块,提供各种图像变换操作
库详细说明:
- torch:PyTorch 的核心库,提供多维张量(类似于 NumPy 的数组)以及各种数学运算和自动微分功能。
- torch.nn:包含了构建神经网络所需的各种模块和工具,如层(
nn.Conv2d
)、激活函数(nn.ReLU
)和损失函数(nn.CrossEntropyLoss
)。 - torch.optim:提供了多种优化算法,如随机梯度下降(SGD)、Adam 等,用于更新模型参数以最小化损失函数。
- torchvision:专注于计算机视觉任务,提供了常用的数据集(如 CIFAR-10、ImageNet)和图像处理工具。
- torchvision.transforms:用于对图像进行预处理和数据增强,如裁剪、缩放、归一化等操作。
小贴士:
torchvision
是处理图像数据的强大工具,结合
transforms
可以轻松进行数据预处理和增强,提升模型的泛化能力。
复习一下:导入正确的库和模块是构建和训练神经网络的第一步。理解每个库的作用,有助于更高效地利用 PyTorch 的功能。
📊 数据准备
Step 3:定义数据预处理
在处理图像数据时,预处理是至关重要的一步。良好的数据预处理不仅能提高模型的训练效率,还能提升模型的最终性能。我们将对图像数据进行以下预处理操作:
- 转换为张量(ToTensor):将 PIL 图像或 NumPy 数组转换为 PyTorch 张量,并将像素值从
[0, 255]
缩放到[0, 1]
。 - 归一化(Normalize):对每个通道进行标准化处理,使其均值为 0.5,标准差为 0.5,进一步将值缩放到
[-1, 1]
。归一化有助于加快模型的收敛速度,并提高训练的稳定性。
transform = transforms.Compose([
transforms.ToTensor(),# 将图像转换为 PyTorch 张量
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))# 归一化到 [-1, 1]])
详细解释:
- transforms.Compose:将多个图像变换操作组合在一起,按顺序依次应用。
- **transforms.ToTensor()**:- 将 PIL 图像或 NumPy 数组转换为 PyTorch 张量。- 自动将像素值从
[0, 255]
缩放到[0, 1]
,并将图像维度从(H, W, C)
转换为(C, H, W)
,以符合 PyTorch 的张量格式。 - **transforms.Normalize(mean, std)**:- 对每个通道分别进行归一化处理。- 公式:
output = (input - mean) / std
- 这里的mean
和std
是每个通道的均值和标准差。- 通过标准化,图像数据被缩放到一个更适合模型训练的范围,通常有助于加快收敛速度并提高模型性能。
复习一下:数据预处理是深度学习中的基础步骤。正确的预处理不仅能提高模型的训练效率,还能提升模型的泛化能力。在本例中,将图像转换为张量并归一化,是训练 CNN 的标准做法。
Step 4:加载数据集
我们将使用 CIFAR-10 数据集,这是一个广泛使用的图像分类数据集,包含 10 个类别的彩色图像,如飞机、汽车、鸟、猫等。PyTorch 的
torchvision
模块提供了方便的数据加载器,能够自动下载并加载数据集。
# 加载训练集
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=4,
shuffle=True,
num_workers=2)# 加载测试集
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=4,
shuffle=False,
num_workers=2)# 定义类别标签
classes =('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
详细解释:
- torchvision.datasets.CIFAR10:-
root
:指定数据集下载和存储的目录。-train
:设置为True
加载训练集,False
加载测试集。-download
:如果数据集未下载,会自动下载。-transform
:应用前面定义的图像预处理操作。 - torch.utils.data.DataLoader:-
trainloader
和testloader
:用于迭代访问训练集和测试集数据。-batch_size=4
:每个批次加载 4 张图片。较小的批次有助于模型更快地更新参数,但可能增加训练时间。-shuffle=True
:训练集数据会在每个 epoch 开始前打乱,增加训练的随机性,帮助模型更好地泛化。-shuffle=False
:测试集数据不需要打乱,保持顺序即可。-num_workers=2
:使用 2 个子进程加载数据,提升数据加载速度。根据系统配置,可以调整此参数以优化性能。
复习一下:数据加载器是 PyTorch 中处理数据的核心工具。通过合理设置
batch_size
、
shuffle
和
num_workers
,可以显著提高数据加载效率和模型训练效果。
温故知新:在深度学习项目中,数据预处理和加载是确保模型训练顺利进行的基础步骤。正确理解和应用这些操作,能为后续的模型构建和训练打下坚实的基础。
🏗 构建 CNN 模型
Step 5:设计网络结构
卷积神经网络(CNN)是深度学习中处理图像数据的强大工具。CNN 的核心在于卷积操作,它能够自动提取图像中的特征。我们将一步步构建一个简单但功能强大的 CNN 模型,适用于 CIFAR-10 数据集。
classSimpleCNN(nn.Module):def__init__(self):super(SimpleCNN, self).__init__()# 第一层卷积:输入3通道(RGB),输出6个特征图,卷积核大小5x5
self.conv1 = nn.Conv2d(3,6,5)# 最大池化层:2x2窗口,步幅2,将特征图尺寸减半
self.pool = nn.MaxPool2d(2,2)# 第二层卷积:输入6通道,输出16个特征图,卷积核大小5x5
self.conv2 = nn.Conv2d(6,16,5)# 全连接层1:输入16*5*5,输出120个节点
self.fc1 = nn.Linear(16*5*5,120)# 全连接层2:输入120,输出84
self.fc2 = nn.Linear(120,84)# 全连接层3:输入84,输出10,对应10个类别
self.fc3 = nn.Linear(84,10)defforward(self, x):# 前向传播过程
x = self.pool(torch.relu(self.conv1(x)))# 卷积1 + ReLU激活 + 池化
x = self.pool(torch.relu(self.conv2(x)))# 卷积2 + ReLU激活 + 池化
x = x.view(-1,16*5*5)# 展平张量
x = torch.relu(self.fc1(x))# 全连接1 + ReLU激活
x = torch.relu(self.fc2(x))# 全连接2 + ReLU激活
x = self.fc3(x)# 全连接3,输出结果return x
详细解释:
- 类定义:
SimpleCNN
继承自nn.Module
,这是构建任何神经网络的基础类。 __init__
方法:- **self.conv1 = nn.Conv2d(3, 6, 5)
**: - 定义第一层卷积层。- 输入通道数为 3(RGB 图像),输出通道数为 6,即生成 6 个特征图。- 卷积核(滤波器)大小为 5x5。- **self.pool = nn.MaxPool2d(2, 2)
**: - 定义一个最大池化层,窗口大小为 2x2,步幅为 2。- 作用是下采样,减少特征图的尺寸,从而降低计算量和防止过拟合。- **self.conv2 = nn.Conv2d(6, 16, 5)
**: - 定义第二层卷积层。- 输入通道数为 6(来自第一层的输出),输出通道数为 16。- 卷积核大小为 5x5。- **self.fc1 = nn.Linear(16 * 5 * 5, 120)
**: - 定义第一层全连接层。- 输入特征数为 16 * 5 * 5(16 个 5x5 的特征图展平后的大小)。- 输出特征数为 120。- **self.fc2 = nn.Linear(120, 84)
**: - 定义第二层全连接层。- 输入特征数为 120,输出特征数为 84。- **self.fc3 = nn.Linear(84, 10)
**: - 定义第三层全连接层。- 输入特征数为 84,输出特征数为 10,对应 CIFAR-10 的 10 个类别。forward
方法:- 定义数据的前向传播路径,即数据如何通过各层进行计算。- **x = self.pool(torch.relu(self.conv1(x)))
**: - 输入数据通过第一层卷积层conv1
进行卷积操作。- 经过 ReLU 激活函数增加非线性。- 最后通过池化层pool
进行下采样。- **x = self.pool(torch.relu(self.conv2(x)))
**: - 输出再次通过第二层卷积层conv2
,ReLU 激活和池化层。- **x = x.view(-1, 16 * 5 * 5)
**: - 将多维特征图展平为一维张量,以便输入到全连接层。--1
表示自动计算该维度的大小,确保数据总量不变。- **x = torch.relu(self.fc1(x))
**: - 通过第一层全连接层fc1
,并应用 ReLU 激活。- **x = torch.relu(self.fc2(x))
**: - 通过第二层全连接层fc2
,并应用 ReLU 激活。- **x = self.fc3(x)
**: - 通过第三层全连接层fc3
,输出最终的分类结果。
详细解释:
- 卷积层(Conv2d):通过多个滤波器扫描输入图像,提取不同的特征(如边缘、纹理)。每个滤波器在整个图像上滑动,生成一个特征图。
- 激活函数(ReLU):增加网络的非线性能力,帮助模型学习复杂的模式。ReLU 的计算简单且有效,能够加快训练速度。
- 池化层(MaxPool2d):通过取局部区域的最大值,减少特征图的尺寸和参数数量,同时保留重要的特征。这有助于防止过拟合。
- 全连接层(Linear):将卷积层提取的空间特征映射到最终的分类结果。全连接层将所有输入特征综合考虑,适用于高层次的特征组合。
复习一下:构建一个 CNN 模型涉及多个层的堆叠,每一层都有特定的功能和参数。理解每一层的作用和参数设置,有助于你更好地设计和优化自己的模型。
Step 6:实例化模型
现在,我们创建模型的实例,为训练做好准备。
net = SimpleCNN()print(net)
详细解释:
- 实例化模型:通过调用
SimpleCNN()
,我们创建了一个SimpleCNN
类的实例net
,这个实例包含了我们定义的所有层和参数。 - 打印模型结构:
print(net)
将输出模型的详细结构,包括每一层的名称、类型和参数。这有助于我们验证模型的正确性,并了解模型的整体架构。示例输出:SimpleCNN( (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (fc1): Linear(in_features=400, out_features=120, bias=True) (fc2): Linear(in_features=120, out_features=84, bias=True) (fc3): Linear(in_features=84, out_features=10, bias=True))
复习一下:实例化模型后,通过打印模型结构,可以直观地了解模型的各层配置,帮助你验证模型是否按照预期构建。
🧮 损失函数和优化器
Step 7:选择损失函数
损失函数用于衡量模型预测值与真实值之间的差距,是指导模型优化的关键指标。对于多分类问题,我们选择 交叉熵损失函数(CrossEntropyLoss),因为它在分类任务中表现出色。
criterion = nn.CrossEntropyLoss()
详细解释:
- **nn.CrossEntropyLoss()**: - 结合了
LogSoftmax
和NLLLoss
(负对数似然损失)。- 适用于多分类问题,模型输出不需要经过 softmax 层,因为CrossEntropyLoss
已经内部处理了。- 计算方式:对于每个样本,交叉熵损失衡量真实类别的概率分布与预测概率分布之间的差异。
数学公式:
CrossEntropyLoss = − ∑ c = 1 C y c log ( y ^ c ) \text{CrossEntropyLoss} = -\sum_{c=1}^{C} y_{c} \log(\hat{y}_{c}) CrossEntropyLoss=−c=1∑Cyclog(y^c)
其中,$ C $ 是类别数,$ y_{c} $ 是真实标签的 one-hot 编码,$ \hat{y}_{c} $ 是预测概率。
复习一下:选择合适的损失函数是成功训练模型的关键。交叉熵损失函数适用于多分类问题,能够有效地指导模型优化,提高分类准确率。
Step 8:选择优化器
优化器负责更新模型参数,以最小化损失函数。这里我们选择 随机梯度下降(SGD)优化器,并设置学习率和动量。
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
详细解释:
- optim.SGD: - SGD(Stochastic Gradient Descent):一种经典的优化算法,通过计算损失函数相对于模型参数的梯度,并沿梯度下降的方向更新参数。- 参数: -
net.parameters()
:传入模型的所有参数,优化器将更新这些参数。-lr=0.001
:学习率,控制每次参数更新的步长。较小的学习率可能导致训练速度慢,而较大的学习率可能导致训练不稳定或发散。-momentum=0.9
:动量,用于加速 SGD 在相关方向上的收敛,减少震荡。动量越大,累积的历史梯度越多,更新过程越平滑。
优化器选择的影响:
- 不同的优化器(如 Adam、RMSprop)在不同任务和数据集上表现不同。
- SGD 结合动量在许多任务上表现良好,尤其是在需要稳定收敛的情况下。
复习一下:优化器的选择和参数设置直接影响模型的训练效果和收敛速度。合理设置学习率和动量,有助于模型更快更好地学习。
温故知新:在实际应用中,可以尝试不同的优化器和参数组合,观察模型的训练表现,选择最适合你任务的优化策略。
🏋️♂️ 模型训练
Step 9:编写训练代码
现在,我们将编写训练循环,让模型在训练数据上学习。每个 epoch 将遍历整个训练集一次,逐步优化模型参数以最小化损失函数。
for epoch inrange(2):# 训练2个 epoch
running_loss =0.0for i, data inenumerate(trainloader,0):
inputs, labels = data
# 清空梯度
optimizer.zero_grad()# 前向传播
outputs = net(inputs)
loss = criterion(outputs, labels)# 反向传播
loss.backward()
optimizer.step()# 累加损失
running_loss += loss.item()if i %2000==1999:# 每2000个批次打印一次print(f'[Epoch: {epoch +1}, Batch: {i +1}] Loss: {running_loss /2000:.3f}')
running_loss =0.0print('🏁 Finished Training')
详细解释:
- 训练循环结构:
for epoch inrange(2):...
- epoch:表示训练的轮数,每个 epoch 包含对整个训练集的完整遍历。- **range(2)**:设置训练 2 个 epoch。根据任务复杂度和数据量,可以调整此值。 - 遍历数据:
for i, data inenumerate(trainloader,0): inputs, labels = data
- **enumerate(trainloader, 0)**:遍历训练数据加载器,每次迭代返回一个批次的数据(输入和标签)。- inputs:输入图像数据,形状为[batch_size, channels, height, width]
。- labels:对应的真实标签,形状为[batch_size]
。 - 清空梯度:
optimizer.zero_grad()
- 在每次参数更新前,需要清除之前累积的梯度,否则梯度会在每次迭代时累加。-optimizer.zero_grad()
会将所有参数的梯度清零。 - 前向传播:
outputs = net(inputs)loss = criterion(outputs, labels)
- **outputs = net(inputs)**:将输入数据传入模型,获取预测输出。- **loss = criterion(outputs, labels)**:计算预测输出与真实标签之间的损失值。 - 反向传播:
loss.backward()
- 计算损失函数相对于模型参数的梯度(即反向传播)。- 这些梯度将用于更新模型参数。 - 更新参数:
optimizer.step()
- 根据计算得到的梯度,使用优化器更新模型参数,以最小化损失函数。 - 监控训练过程:
running_loss += loss.item()if i %2000==1999:print(f'[Epoch: {epoch +1}, Batch: {i +1}] Loss: {running_loss /2000:.3f}') running_loss =0.0
- running_loss:累积损失值,用于计算平均损失。- if i % 2000 == 1999:每训练 2000 个批次,打印一次当前的平均损失。- **loss.item()**:获取当前批次的损失值(标量)。- **print(…)**:输出当前 epoch 和批次的损失,帮助监控训练进展。
详细步骤:
- 遍历 Epoch:每个 epoch 表示训练数据集的完整遍历。增加 epoch 数量可以让模型有更多机会学习数据特征,但过多的 epoch 可能导致过拟合。
- 获取数据:从数据加载器中获取输入数据和对应的标签。
- 清空梯度:在每次迭代前清除之前的梯度,避免梯度累加。
- 前向传播:通过模型计算输出,并计算损失值。
- 反向传播:计算梯度,准备更新模型参数。
- 更新参数:优化器根据梯度更新模型参数。
- 监控损失:定期打印损失值,观察训练过程中的模型表现。
复习一下:训练过程中,损失值的逐步降低意味着模型在不断学习和优化,表现也在逐步提升。通过监控损失值,可以及时调整训练策略,如学习率调整、模型结构修改等。
温故知新:在实际应用中,可以通过绘制损失曲线和准确率曲线,直观地了解模型的训练进展和性能变化。
🧪 模型测试
Step 10:在测试集上评估模型
训练完成后,我们需要使用测试集评估模型的性能,计算整体准确率。这一步帮助我们了解模型在未见过的数据上的表现,从而评估其泛化能力。
correct =0
total =0with torch.no_grad():# 测试时不需要计算梯度for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data,1)
total += labels.size(0)
correct +=(predicted == labels).sum().item()print(f'✅ Accuracy of the network on the 10000 test images: {100* correct / total:.2f}%')
详细解释:
- 变量初始化:
correct =0total =0
- correct:记录正确预测的样本数。- total:记录总的样本数。 - 关闭梯度计算:
with torch.no_grad():...
- 在测试阶段,不需要进行反向传播和梯度计算,使用torch.no_grad()
可以节省内存和计算资源,提高测试速度。 - 遍历测试数据:
for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data,1) total += labels.size(0) correct +=(predicted == labels).sum().item()
- images, labels = data:获取测试集中的输入图像和真实标签。- **outputs = net(images)**:将输入图像传入模型,获取预测输出。- **_, predicted = torch.max(outputs.data, 1)**: -torch.max(outputs.data, 1)
:沿着类别维度(维度 1)取最大值,返回值和索引。-_
:保留最大值,但不使用。- predicted:每个样本的预测类别索引。- **total += labels.size(0)**:累加当前批次的样本数。- **correct += (predicted == labels).sum().item()**:比较预测结果与真实标签,累加正确预测的样本数。 - 计算并输出准确率:
print(f'✅ Accuracy of the network on the 10000 test images: {100* correct / total:.2f}%')
- 计算整体准确率,并以百分比形式输出。
复习一下:在测试阶段,通过计算模型在测试集上的准确率,可以评估其在实际应用中的表现。高准确率意味着模型在未见过的数据上具有良好的泛化能力。
温故知新:除了准确率,还可以使用其他评估指标,如精确率、召回率和 F1 分数,进一步深入了解模型的性能。
Step 11:按类别查看准确率
为了更细致地了解模型在各个类别上的表现,我们可以计算每个类别的准确率。这有助于发现模型在哪些类别上表现良好,哪些类别需要进一步优化。
class_correct =list(0.for i inrange(10))
class_total =list(0.for i inrange(10))with torch.no_grad():for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs,1)
c =(predicted == labels).squeeze()for i inrange(len(labels)):
label = labels[i]
class_correct[label]+= c[i].item()
class_total[label]+=1for i inrange(10):if class_total[i]>0:
accuracy =100* class_correct[i]/ class_total[i]else:
accuracy =0print(f'Accuracy of {classes[i]:5s} : {accuracy:.2f}%')
详细解释:
- 变量初始化:
class_correct =list(0.for i inrange(10))class_total =list(0.for i inrange(10))
- class_correct:记录每个类别的正确预测数,初始值为 0。- class_total:记录每个类别的总样本数,初始值为 0。 - 遍历测试数据:
with torch.no_grad():for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs,1) c =(predicted == labels).squeeze()for i inrange(len(labels)): label = labels[i] class_correct[label]+= c[i].item() class_total[label]+=1
- **c = (predicted == labels).squeeze()**:- 比较预测结果与真实标签,生成一个布尔张量,表示每个样本是否预测正确。-squeeze()
:移除尺寸为 1 的维度,简化张量形状。- 遍历每个样本:for i inrange(len(labels)): label = labels[i] class_correct[label]+= c[i].item() class_total[label]+=1
- **label = labels[i]**:获取第i
个样本的真实标签。- **class_correct[label] += c[i].item()**:如果预测正确,累加对应类别的正确预测数。- class_total[label] += 1:累加对应类别的总样本数。 - 计算并输出每个类别的准确率:
for i inrange(10):if class_total[i]>0: accuracy =100* class_correct[i]/ class_total[i]else: accuracy =0print(f'Accuracy of {classes[i]:5s} : {accuracy:.2f}%')
- 遍历每个类别,计算准确率。- 如果某个类别的总样本数为 0,准确率设为 0,避免除零错误。- 输出每个类别的准确率,格式化为两位小数。
复习一下:按类别计算准确率可以帮助我们发现模型在不同类别上的表现差异。某些类别可能因为数据量少或特征复杂而表现较差,需要进一步优化。
温故知新:深入分析模型在各个类别上的表现,可以指导我们进行有针对性的改进,如增加数据量、调整模型结构或使用数据增强技术。
💾 模型保存与加载
Step 12:保存模型
训练好的模型可以保存下来,以便日后复用或部署。我们使用
torch.save()
保存模型的状态字典(
state_dict
),这是一种推荐的保存方式,因为它只保存模型的参数,而不包含模型的结构。
torch.save(net.state_dict(),'cnn_model.pth')# 保存模型参数print('💾 Model saved to cnn_model.pth')
详细解释:
- **torch.save()**:- 用于保存 PyTorch 对象到文件。-
net.state_dict()
:返回模型的状态字典,包含了模型的所有参数(权重和偏置)。-'cnn_model.pth'
:保存文件的路径和名称,通常使用.pth
或.pt
作为扩展名。 - 为何保存
state_dict
而不是整个模型:- 灵活性:保存state_dict
只包含参数,不包含模型结构。加载时需要重新定义模型结构,确保与保存时一致。- 兼容性:适用于不同的代码环境,不依赖于代码文件的完整性。
复习一下:保存模型的状态字典是最佳实践,既能保留模型的学习成果,又保持了灵活性和兼容性。
Step 13:加载模型
如果需要使用保存的模型,可以通过加载状态字典来恢复模型参数。以下是加载模型的步骤:
net = SimpleCNN()# 重新创建模型实例
net.load_state_dict(torch.load('cnn_model.pth'))# 加载参数
net.eval()# 设置模型为评估模式print('📥 Model loaded from cnn_model.pth')
详细解释:
- 重新创建模型实例:
net = SimpleCNN()
- 创建一个新的SimpleCNN
实例,结构必须与保存时一致。 - 加载状态字典:
net.load_state_dict(torch.load('cnn_model.pth'))
- **torch.load(‘cnn_model.pth’)**:从文件中加载保存的状态字典。- **net.load_state_dict(…)**:将加载的参数赋值给模型实例。 - 设置模型为评估模式:
net.eval()
- 将模型设置为评估模式,影响模型中某些层的行为,如 Dropout 和 BatchNorm。- 在评估模式下,Dropout 层不会随机丢弃神经元,BatchNorm 层使用全局均值和方差。
注意:
- 一致性:确保加载参数时模型结构与保存时一致,否则会导致错误。
- 评估模式:在进行模型推理或评估时,必须调用
net.eval()
,以确保模型行为正确。
复习一下:模型的保存与加载是实际应用中不可或缺的步骤。通过保存
state_dict
,我们可以轻松地复用和部署模型,确保模型的学习成果得以保留和应用。
温故知新:在实际项目中,常常需要保存多个版本的模型,或者根据不同需求加载不同的模型参数。合理管理模型文件,有助于项目的可维护性和扩展性。
📚 总结
🎉 恭喜你!通过这篇详尽且专业的教程,你已经掌握了使用 PyTorch 构建、训练和评估一个基础的卷积神经网络(CNN)所需的所有步骤。从环境配置、数据预处理、模型设计,到训练、测试,再到模型的保存与加载,你已经全面了解了一个完整的深度学习工作流。
📈 后续提升建议
- 增加训练轮数(Epochs):尝试训练更多轮数,观察模型性能的提升。更多的训练轮数通常能让模型学习得更充分,但也要注意防止过拟合。
- 调整网络结构:增加卷积层或全连接层,改变滤波器数量,探索不同结构对性能的影响。更深或更宽的网络可能提取更丰富的特征,但也会增加计算量和过拟合风险。
- 尝试不同优化器:例如 Adam 优化器,看看能否加快收敛速度或提升准确率。不同的优化器在不同任务和数据集上表现各异,选择合适的优化器有助于提升模型性能。
optimizer = optim.Adam(net.parameters(), lr=0.001)
- 数据增强:通过随机裁剪、旋转、翻转等方法扩展数据集,提高模型的泛化能力。例如,使用
transforms
进行数据增强:transform = transforms.Compose([ transforms.RandomHorizontalFlip(),# 随机水平翻转 transforms.RandomCrop(32, padding=4),# 随机裁剪并填充 transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
- 迁移学习:使用预训练模型,如 ResNet、VGG 等,进行微调,提升模型性能。迁移学习能够利用在大型数据集上预训练的模型权重,加速训练过程并提升准确率。
import torchvision.models as models# 加载预训练的 ResNet18 模型resnet = models.resnet18(pretrained=True)# 替换最后的全连接层,适应 CIFAR-10 的 10 个类别resnet.fc = nn.Linear(resnet.fc.in_features,10)# 使用新的模型进行训练net = resnet
复习一下:持续学习和实践是掌握深度学习的关键。通过不断调整模型结构、优化训练策略和应用新技术,你可以不断提升模型的性能和应用范围。
温故知新:深度学习领域发展迅速,保持对新技术和方法的关注,有助于你在实际项目中应用最前沿的技术,提升竞争力。
🔍 别忘了,持续学习和实践是掌握深度学习的关键!如果你喜欢这篇教程,欢迎关注我们的公众号“默子AI”📱,获取更多实用的技术干货和最新资讯。让我们一起在 AI 的世界中不断前行,探索更多可能!
随时欢迎你的加入,开启你的 AI 之旅吧!🚀
版权归原作者 默子要早睡.Histone 所有, 如有侵权,请联系我们删除。