PyTorch: Building the VGG Network

🚗Introducing VGG

VGG was proposed in 2014 by the renowned VGG (Visual Geometry Group) at the University of Oxford

(paper: https://arxiv.org/abs/1409.1556)

and won first place in the Localization Task and second place in the Classification Task of that year's ImageNet competition. (Very impressive, to say the least.)

🚓So what exactly makes VGG so strong?

By stacking several small convolution kernels in place of one large kernel, it reduces the number of trainable parameters while keeping the receptive field the same.

🚓And what is a receptive field?

The region of the input layer that determines one element of a layer's output is called the receptive field.

Simply put, it is the region of the input (previous) layer that one unit of the output feature map corresponds to.

🚗For example, in the figure above: maxpool1 has a receptive field of 2 (one cell in this layer corresponds to 2 cells in the layer below), and conv1 has a receptive field of 5.

🚓Calculation formula

The receptive field is computed with the recurrence:

F(i + 1) = (F(i) - 1) * Stride + Ksize

F(i + 1) is the receptive field of layer i + 1,
Stride is the stride of layer i,
Ksize is the size of the convolution or pooling kernel.

🚓Question 1:

Stack two 3×3 kernels to replace one 5×5 kernel, and three 3×3 kernels to replace one 7×7 kernel.

(In VGG, convolutions use Stride = 1 by default.)

Is the receptive field the same before and after the substitution?

By the formula:

(Layer 1) Feature map: F(1) = 1

(Layer 2) Conv3x3(3): F(2) = (F(1) - 1) * 1 + 3 = (1 - 1) * 1 + 3 = 3

(Layer 3) Conv3x3(2): F(3) = (F(2) - 1) * 1 + 3 = (3 - 1) * 1 + 3 = 5

(the receptive field of a 5×5 kernel)

(Layer 4) Conv3x3(1): F(4) = (F(3) - 1) * 1 + 3 = (5 - 1) * 1 + 3 = 7

(the receptive field of a 7×7 kernel)

So two 3×3 kernels have the same receptive field as one 5×5 kernel.

This proves that **two stacked 3×3 kernels can replace a 5×5 kernel, and three stacked 3×3 kernels can replace a 7×7 kernel**.
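The recurrence above is easy to check in code. Here is a minimal sketch (pure Python, no PyTorch needed; the helper name `receptive_field` is my own, not from the original post):

```python
def receptive_field(layers):
    """Walk the recurrence F(i+1) = (F(i) - 1) * stride + ksize,
    starting from a single unit of the final feature map (F = 1).
    `layers` is a list of (ksize, stride) pairs."""
    f = 1  # one unit of the output feature map
    for ksize, stride in layers:
        f = (f - 1) * stride + ksize
    return f

# two stacked 3x3 convs (stride 1) cover the same input region as one 5x5
print(receptive_field([(3, 1), (3, 1)]))          # 5
print(receptive_field([(5, 1)]))                  # 5
# three stacked 3x3 convs match one 7x7
print(receptive_field([(3, 1), (3, 1), (3, 1)]))  # 7
```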

🚓Question 2:

After stacking 3×3 kernels, is the number of trainable parameters really reduced?

Note: number of CNN parameters = kernel size × kernel depth × number of kernel groups = kernel size × input feature-map depth × output feature-map depth.
Now assume input feature-map depth = output feature-map depth = C.

Parameters needed by one 7×7 kernel:

7 × 7 × C × C = 49C²

Parameters needed by three stacked 3×3 kernels:

3 × (3 × 3 × C × C) = 27C²

**Clearly, 27C² is less than 49C².**
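Under the assumption above (input depth = output depth = C), the comparison is quick to verify numerically, e.g. for C = 64 (the helper name `conv_params` is mine, for illustration only):

```python
def conv_params(ksize, in_ch, out_ch):
    # weights only (ignoring biases): ksize * ksize * in_ch * out_ch
    return ksize * ksize * in_ch * out_ch

C = 64
one_7x7 = conv_params(7, C, C)        # 49 * C^2 = 200704
three_3x3 = 3 * conv_params(3, C, C)  # 27 * C^2 = 110592
print(one_7x7, three_3x3)
assert three_3x3 < one_7x7
```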

🚓Network diagram

The VGG network comes in several versions.

We usually use VGG16 (the 16 means 16 weight layers = 13 convolutional layers + 3 fully connected layers).

Its network structure is shown in the figure below:

From the figure and a quick calculation we can see that a 3×3 convolution (stride 1, padding 1) does not change the size of the feature map:

out = (in - F + 2P)/S + 1 = (in - 3 + 2)/1 + 1 = in

so out = in: the size stays the same.
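The output-size formula can be sketched as a small helper (the name `conv_out_size` is my own):

```python
def conv_out_size(size_in, ksize, padding, stride):
    # out = (in - F + 2P) / S + 1
    return (size_in - ksize + 2 * padding) // stride + 1

# 3x3 conv, padding 1, stride 1: spatial size is preserved
print(conv_out_size(224, ksize=3, padding=1, stride=1))  # 224
# 2x2 max-pool, stride 2 (as used between VGG stages): size is halved
print(conv_out_size(224, ksize=2, padding=0, stride=2))  # 112
```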

🚗Building VGG in PyTorch

The VGG network splits into two modules: convolutional layers for feature extraction, and fully connected layers for classification.

🚓1. model.py

```python
import torch.nn as nn
import torch


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features  # convolutional layers: feature extraction
        self.classifier = nn.Sequential(  # fully connected layers: classification
            nn.Dropout(p=0.5),
            nn.Linear(512*7*7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()  # initialize weights

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
```

🚕The clever part

```python
# VGG configuration lists: numbers are convolution kernel counts, 'M' marks a max-pooling layer
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # model A
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # model B
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],  # model D
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],  # model E
}


# convolutional feature-extraction layers
def make_features(cfg: list):  # receives the configuration list of one specific model
    layers = []
    in_channels = 3  # the input image is RGB, so 3 channels
    for v in cfg:
        # 'M' means a max-pooling layer
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        # otherwise it is a convolutional layer
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)  # the single star (*) unpacks the list into positional arguments


def vgg(model_name="vgg16", **kwargs):  # the double star (**) collects keyword arguments into a dict
    try:
        cfg = cfgs[model_name]
    except KeyError:
        print("Warning: model name {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg), **kwargs)  # **kwargs forwards the keyword arguments you pass in
    return model
```
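As a sanity check on the naming, each variant's number in its name equals its conv-layer count plus the 3 Linear layers of the classifier, and that can be verified straight from the configuration lists (the `cfgs` dict below repeats the one above so the snippet is self-contained):

```python
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

NUM_FC = 3  # the classifier has 3 Linear layers

for name, cfg in cfgs.items():
    num_conv = sum(1 for v in cfg if v != 'M')  # every non-'M' entry is one conv layer
    print(name, num_conv + NUM_FC)  # vgg16 -> 16, vgg19 -> 19, ...
```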

🚓2. train.py

Same as in pytorch——AlexNet——训练花分类数据集_heart_6662的博客-CSDN博客 (the data is still the flower dataset).

```python
import os
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=0)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
```

🚓3. predict.py

Same as before: pytorch——AlexNet——训练花分类数据集_heart_6662的博客-CSDN博客

```python
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    json_file = open(json_path, "r")
    class_indict = json.load(json_file)

    # create model
    model = vgg(model_name="vgg16", num_classes=5).to(device)
    # load model weights
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
```

🚗Note

The VGG network is quite deep and needs a powerful GPU to train (and one with plenty of memory; my 3050 could not handle it, and PyTorch raised a GPU out-of-memory error).

You can also try reducing batch_size.


Reposted from: https://blog.csdn.net/qq_62932195/article/details/122416591
Copyright belongs to the original author, heart_6662. In case of infringement, please contact us for removal.
