0


Pytorch入门(五)使用ResNet-18网络训练常规状态下的CIFAR10数据集

本文采用ResNet-18+Pytorch+CIFAR-10实现深度学习的训练。

文章目录

一、CIFAR-10 数据集介绍

CIFAR10数据集是一个用于识别普适物体的小型数据集,一共包含10个类别的RGB彩色图片,图片尺寸大小为32x32,如图:
在这里插入图片描述
相较于MNIST数据集,MNIST数据集是28x28的单通道灰度图,而CIFAR10数据集是32x32的RGB三通道彩色图,CIFAR10数据集更接近于真实世界的图片。这里我采用的是定制CIFAR10数据集,数据集目录结构如下(训练集包含5w张图片,测试集包含1w张图片):
在这里插入图片描述

二、ResNet 神经网络的介绍

1.ResNet 的网络模型

在这里插入图片描述

本文采用ResNet18来构建深度网络模型,下面是ResNet18与ResNet50的对比。
在这里插入图片描述

2.本文用到的ResNet网络结构

本文用到的ResNet-18 的层次结构:

  • 输入层:尺寸为32x32的RGB图像。
  • 卷积层1:64个3x3的卷积核,步长为1,padding为1,生成64个特征图。
  • 批量归一化层1:对卷积层1的输出进行批量归一化操作。
  • ReLU激活函数:对批量归一化层1的输出应用ReLU激活函数。
  • 残差块1:由两个基本的残差单元组成。
  • 残差块2:由两个基本的残差单元组成。
  • 残差块3:由两个基本的残差单元组成。
  • 残差块4:由两个基本的残差单元组成。
  • 全局平均池化层:对最后一个残差块的输出应用全局平均池化操作。
  • 全连接层:将池化层的输出连接到一个全连接层,用于最终的分类操作。
  • Softmax激活函数:对全连接层的输出应用Softmax激活函数,生成最终的概率输出。

我们的ResNet18网络结构示意图大致如下:

  1. 输入
  2. |
  3. 卷积层,643x3的卷积核, 步长1|
  4. 批量归一化
  5. |
  6. ReLU激活函数
  7. |
  8. 残差块1|
  9. 残差块2|
  10. 残差块3|
  11. 残差块4|
  12. 全局平均池化
  13. |
  14. 全连接层, 输出类别数
  15. |
  16. Softmax激活函数
  17. |
  18. 输出

上述的每一个残差块都由两个卷积层组成,具体结构如下:

  1. 残差块:|
  2. 卷积层,643x3的卷积核, 步长1|
  3. 批量归一化
  4. |
  5. ReLU激活函数
  6. |
  7. 卷积层,643x3的卷积核, 步长1|
  8. 批量归一化
  9. |
  10. 跳跃连接
  11. |
  12. ReLU激活函数

3.残差块的的解释

残差块(Residual Block)是 ResNet-18 网络中的重要组成部分,它的作用是帮助网络有效地学习深层特征表示。由于深层神经网络存在梯度消失和梯度爆炸的问题,传统的网络难以有效地训练和优化。残差块的引入有效地解决了这个问题。

残差块的核心思想是引入了一个跳跃连接(skip connection),使得信息可以直接从输入层流经残差块并与残差块的输出相加。这样,网络可以直接学习残差(即差异),而不仅仅是学习特征变换。这种跳跃连接允许梯度在反向传播过程中更容易地传播,从而避免了梯度消失和梯度爆炸问题。

具体来说,残差块中的两个卷积层(或更多卷积层)形成了一种特征变换,将输入特征图映射到更高维度的特征空间。然后,跳跃连接将输入特征图与残差块的输出相加,形成残差。最后,通过对残差应用激活函数,产生残差块的输出。

残差块的存在使得网络能够更好地优化深层网络,加深网络的深度,并在保持网络性能的同时提高训练速度和效果。

4.ResNet神经网络的优缺点

ResNet-18 是一个经典的深度残差网络,在深度学习领域中取得了很大的成功。它具有以下的优点和缺点。

优点:

  • 解决了深层网络中的梯度消失和梯度爆炸问题:通过引入残差块和跳跃连接,ResNet-18 允许梯度在网络中更容易地传播,有助于训练更深的网络。
  • 提高了网络的训练效果和表达能力:深层残差结构有助于网络学习更复杂、更抽象的特征表示,可以提高网络的准确性和泛化能力。
  • 减少了参数数量:相比于传统的网络结构,ResNet-18 的残差块允许跳跃连接,使得网络可以跳过一些不必要的卷积层,从而减少了参数数量,减轻了过拟合的风险。
  • 在计算资源允许的情况下,可以通过增加网络的深度进一步提升性能:ResNet-18 可以作为基础模型,通过增加残差块的数量或者使用更深的变体(如 ResNet-34、ResNet-50 等)来进一步提升性能。

缺点:

  • 模型较为复杂:ResNet-18 的网络结构相对复杂,需要更多的计算资源和存储空间来训练和部署。
  • 对较小的数据集可能会过拟合:由于 ResNet-18 的深度和参数数量较多,当训练数据集较小时,可能会出现过拟合的问题。针对小数据集的训练,可以采用数据增强、正则化等方法来缓解过拟合。
  • 训练时间较长:由于 ResNet-18 较深且复杂,相对于一些浅层网络结构,它的训练时间可能会更长。

总体而言,ResNet-18 是一个非常强大的深度学习网络,它的优点在很多任务上得到了证明,但在特定的应用场景中仍然需要根据具体情况权衡其优缺点。


三、ResNet-18 代码实现

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. # 残差块classResidualBlock(nn.Module):def__init__(self, inchannel, outchannel, stride=1):super(ResidualBlock, self).__init__()
  4. self.left = nn.Sequential(
  5. nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
  6. nn.BatchNorm2d(outchannel),
  7. nn.ReLU(inplace=True),
  8. nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
  9. nn.BatchNorm2d(outchannel))
  10. self.shortcut = nn.Sequential()if stride !=1or inchannel != outchannel:
  11. self.shortcut = nn.Sequential(
  12. nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
  13. nn.BatchNorm2d(outchannel))defforward(self, x):
  14. out = self.left(x)
  15. out += self.shortcut(x)
  16. out = F.relu(out)return out
  17. classResNet(nn.Module):def__init__(self, ResidualBlock, num_classes=10):super(ResNet, self).__init__()
  18. self.inchannel =64
  19. self.conv1 = nn.Sequential(
  20. nn.Conv2d(3,64, kernel_size=3, stride=1, padding=1, bias=False),
  21. nn.BatchNorm2d(64),
  22. nn.ReLU(),)
  23. self.layer1 = self.make_layer(ResidualBlock,64,2, stride=1)
  24. self.layer2 = self.make_layer(ResidualBlock,128,2, stride=2)
  25. self.layer3 = self.make_layer(ResidualBlock,256,2, stride=2)
  26. self.layer4 = self.make_layer(ResidualBlock,512,2, stride=2)
  27. self.fc = nn.Linear(512, num_classes)defmake_layer(self, block, channels, num_blocks, stride):
  28. strides =[stride]+[1]*(num_blocks -1)#strides=[1,1]
  29. layers =[]for stride in strides:
  30. layers.append(block(self.inchannel, channels, stride))
  31. self.inchannel = channels
  32. return nn.Sequential(*layers)defforward(self, x):
  33. out = self.conv1(x)
  34. out = self.layer1(out)
  35. out = self.layer2(out)
  36. out = self.layer3(out)
  37. out = self.layer4(out)
  38. out = F.avg_pool2d(out,4)
  39. out = out.view(out.size(0),-1)
  40. out = self.fc(out)return out
  41. defResNet18():return ResNet(ResidualBlock)

四、ResNet-18 训练 CIFAR-10数据集

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torchvision
  5. import torchvision.transforms as transforms
  6. import argparse
  7. from torch.utils.tensorboard import SummaryWriter
  8. from resnet_model import ResNet18
  9. # 定义是否使用GPU
  10. device = torch.device("cuda"if torch.cuda.is_available()else"cpu")# 参数设置,使得我行差们能够手动输入命令行参数,就是让风格变得和Linux命令不多# parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')# parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') # 输出结果保存路径# parser.add_argument('--net', default='./model/Resnet18.pth', help="path to net (to continue training)") # 恢复训练时的模型路径# args = parser.parse_args()# 超参数设置
  11. EPOCH =200# 遍历数据集次数
  12. pre_epoch =0# 定义已经遍历数据集的次数
  13. BATCH_SIZE =128# 批处理尺寸(batch_size)
  14. LR =0.001# 学习率# print("开始加载CIFAR10数据集!")# 准备数据集并预处理
  15. transform_train = transforms.Compose([
  16. transforms.RandomCrop(32, padding=4),# 先四周填充0,在吧图像随机裁剪成32*32
  17. transforms.RandomHorizontalFlip(),# 图像一半的概率翻转,一半的概率不翻转
  18. transforms.ToTensor(),
  19. transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010)),# R,G,B每层的归一化用到的均值和方差])
  20. transform_test = transforms.Compose([
  21. transforms.ToTensor(),
  22. transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010)),])
  23. trainset = torchvision.datasets.ImageFolder(root='data/train', transform=transform_train)# 训练数据集
  24. trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
  25. num_workers=2)# 生成一个个batch进行批训练,组成batch的时候顺序打乱取
  26. testset = torchvision.datasets.ImageFolder(root='data/test', transform=transform_test)
  27. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True, num_workers=2)# Cifar-10的标签
  28. classes =('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')# print("CIFAR10数据集加载完毕!")# print("开始ResNet网络模型初始化!")# 模型定义-ResNet
  29. resnet18 = ResNet18().to(device)# 定义损失函数和优化方式
  30. loss_fn = nn.CrossEntropyLoss()# 损失函数为交叉熵,多用于多分类问题
  31. loss_fn=loss_fn.to(device)
  32. optimizer = optim.SGD(resnet18.parameters(), lr=LR, momentum=0.9,
  33. weight_decay=5e-4)# 优化方式为mini-batch momentum-SGD,并采用L2正则化(权重衰减)# 记录训练的次数
  34. total_train_step =0# 记录测试的次数
  35. total_test_step =0# 添加tensorboard画图可视化
  36. writer = SummaryWriter("logs_train")# print("ResNet网络模型初始化完毕!")# 训练if __name__ =="__main__":
  37. best_acc =85# 2 初始化best test accuracy
  38. best_epoch=0# 有需要可以打开,接着上次训练好的权重训练# print("加载模型...")# with open("pth/resnet18_12.pth",'rb') as f:# resnet18.load_state_dict(torch.load(f))# print("加载完毕!")print("开始训练! Resnet-18! 冲!")# 定义遍历数据集的次数for epoch inrange(pre_epoch, EPOCH):print(f'--------第{epoch +1}轮训练开始---------')
  39. resnet18.train()
  40. total_train_loss =0.0
  41. correct =0.0
  42. total =0.0for data in trainloader:# print("-------",i)# 准备数据
  43. inputs, labels = data
  44. inputs, labels = inputs.to(device), labels.to(device)
  45. optimizer.zero_grad()# forward + backward
  46. outputs = resnet18(inputs)
  47. loss = loss_fn(outputs, labels)
  48. loss.backward()
  49. optimizer.step()# 每训练100batch打印一次loss和准确率
  50. total_train_loss += loss.item()
  51. _, predicted = torch.max(outputs.data,1)
  52. total += labels.size(0)
  53. total_train_step+=1
  54. correct += predicted.eq(labels.data).cpu().sum()if total_train_step %100==0:print('[训练次数:%d] Loss: %.03f'%(total_train_step, total_train_loss))
  55. writer.add_scalar("train_loss", loss.item(), total_train_step)# 每训练完一个epoch测试一下准确率print("开始测试!")with torch.no_grad():
  56. correct =0
  57. total =0
  58. total_test_loss=0for data in testloader:
  59. resnet18.eval()
  60. images, labels = data
  61. images, labels = images.to(device), labels.to(device)
  62. outputs = resnet18(images)
  63. loss = loss_fn(outputs, labels)
  64. total_test_loss+=loss.item()# 取得分最高的那个类 (outputs.data的索引号)
  65. _, predicted = torch.max(outputs.data,1)
  66. total += labels.size(0)
  67. correct +=(predicted == labels).sum().item()# result = torch.floor_divide(correct, total)# print('测试分类准确率为:%.3f%%' % (100 * result))
  68. acc =100* correct / total
  69. print(f"测试集上的loss:{total_test_loss}")print(f'测试分类准确率为:{acc}')# 将每次测试结果实时写入acc.txt文件中print('Saving model......')
  70. torch.save(resnet18.state_dict(),f'pth/resnet18_{epoch +1}.pth')
  71. writer.add_scalar("test_loss", total_test_loss, total_test_step)
  72. total_test_step = total_test_step +1# 记录最佳测试分类准确率并写入best_acc.txt文件中if acc > best_acc:
  73. f3 =open("best_acc.txt","w")
  74. f3.write(f"训练轮次为{epoch +1}时,准确率最高!准确率为{acc}")
  75. f3.close()
  76. best_acc = acc
  77. print("训练结束!")

五、使用训练好的权重分类

  1. import torch
  2. import torchvision.transforms as transforms
  3. from resnet_model import ResNet18
  4. from PIL import Image
  5. import os
  6. # 定义加载图片的方式# transformed=transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor()])defpredict_(img):
  7. data_transform = transforms.Compose([
  8. transforms.ToTensor(),
  9. transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010)),
  10. transforms.Resize((32,32))])if img.mode !="RGB":
  11. img = img.convert("RGB")
  12. img = data_transform(img)
  13. img = torch.unsqueeze(img, dim=0)
  14. model = ResNet18()
  15. model_weight_pth ='pth/resnet18_181.pth'
  16. model.load_state_dict(torch.load(model_weight_pth,map_location="cpu"))
  17. model.eval()
  18. classes ={'0':'飞机','1':'汽车','2':'鸟','3':'猫','4':'鹿','5':'狗','6':'青蛙','7':'马','8':'船','9':'卡车'}with torch.no_grad():
  19. output = torch.squeeze(model(img))print(output)
  20. predict = torch.softmax(output, dim=0)
  21. predict_cla = torch.argmax(predict).numpy()return classes[str(predict_cla)], predict[predict_cla].item()'''
  22. CIFAR10包含哪几类 这10类分别是airplane (飞机),automobile(汽车),bird(鸟),cat(猫),deer(鹿),
  23. dog(狗),frog(青蛙),horse(马),ship(船)和truck(卡车)
  24. '''
  25. basepath=os.path.split(os.path.split(os.getcwd())[0])[0]if __name__=="__main__":while1:
  26. img_path=input("请输入检测图片的名称:")
  27. img=Image.open(basepath+rf"\imgs\{img_path}.png")print(predict_(img))

可以看到这张图片的准确率为0.9999,在整个测试集1w张图片上。
这个权重文件的识别准确率为88.54%。
在这里插入图片描述
在这里插入图片描述

六、实现一个GUI页面

有了上面的权重文件不设计一个GUI页面怎么能配的上。

  1. from PyQt5.QtWidgets import(QWidget,QLCDNumber,QSlider,QMainWindow,
  2. QGridLayout,QApplication,QPushButton, QLabel, QLineEdit)from PyQt5.QtGui import*from PyQt5.QtCore import*from PyQt5.QtWidgets import*import sys
  3. from PyQt5.QtCore import Qt
  4. from resnet_predict import predict_
  5. from PIL import Image
  6. classUi_example(QWidget):def__init__(self):super().__init__()
  7. self.layout = QGridLayout(self)
  8. self.label_image = QLabel(self)
  9. self.label_predict_result = QLabel('识别结果',self)
  10. self.label_predict_result_display = QLabel(self)
  11. self.label_predict_acc = QLabel('识别准确率',self)
  12. self.label_predict_acc_display = QLabel(self)
  13. self.button_search_image = QPushButton('选择图片',self)
  14. self.button_run = QPushButton('运行',self)
  15. self.setLayout(self.layout)
  16. self.initUi()definitUi(self):
  17. self.layout.addWidget(self.label_image,1,1,3,2)
  18. self.layout.addWidget(self.button_search_image,1,3,1,2)
  19. self.layout.addWidget(self.button_run,3,3,1,2)
  20. self.layout.addWidget(self.label_predict_result,4,3,1,1)
  21. self.layout.addWidget(self.label_predict_result_display,4,4,1,1)
  22. self.layout.addWidget(self.label_predict_acc,5,3,1,1)
  23. self.layout.addWidget(self.label_predict_acc_display,5,4,1,1)
  24. self.button_search_image.clicked.connect(self.openimage)
  25. self.button_run.clicked.connect(self.run)
  26. self.setGeometry(300,300,300,300)
  27. self.setWindowTitle('CLFAR-10十分类')
  28. self.show()defopenimage(self):global fname
  29. imgName, imgType = QFileDialog.getOpenFileName(self,"选择图片","","*.jpg;;*.png;;All Files(*)")
  30. jpg = QPixmap(imgName).scaled(self.label_image.width(), self.label_image.height())
  31. self.label_image.setPixmap(jpg)
  32. fname = imgName
  33. defrun(self):global fname
  34. file_name =str(fname)
  35. img = Image.open(file_name)
  36. a, b = predict_(img)
  37. self.label_predict_result_display.setText(a)
  38. self.label_predict_acc_display.setText(str(b))if __name__ =='__main__':
  39. app = QApplication(sys.argv)
  40. ex = Ui_example()
  41. sys.exit(app.exec_())

运行结果如下图所示。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

ResNet 网络中一般不使用全连接层,而是在最后一层使用全局平均池化层和一个全连接层来进行分类。

具体来说,ResNet 的最后一层是一个全局平均池化层,它对最后一个残差块的输出特征图进行平均池化操作,将特征图的高维信息压缩成一个特征向量。

随后,这个特征向量会通过一个全连接层进行分类,将其映射到最终的类别标签上。这个全连接层的输出经过 Softmax 激活函数,生成最终的概率分布。

使用全局平均池化层和一个全连接层可以将整个网络的参数量大大减小,减轻过拟合的风险,并且使网络更容易优化和训练。此外,这种结构也使得网络更加适应不同尺寸的输入图像。

因此,ResNet 网络中使用的全连接层是用于最后的分类操作,而不是用在网络的中间层。


在这里插入图片描述


本文转载自: https://blog.csdn.net/apple_51931783/article/details/130839532
版权归原作者 酷尔。 所有, 如有侵权,请联系我们删除。

“Pytorch入门(五)使用ResNet-18网络训练常规状态下的CIFAR10数据集”的评论:

还没有评论