

Building an LSTM in PyTorch for Clothing Classification (FashionMNIST)

FashionMNIST dataset homepage: https://github.com/zalandoresearch/fashion-mnist.

The dataset itself is not described here; see the homepage for details.

Approach: every image in the dataset is a $(28, 28)$ grayscale image. We can treat it as a $28\times 28$ matrix of numbers and **split it row by row**, which gives a sequence of length $28$ in which each "token" (one row) also has a feature dimension of $28$.
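The row-by-row view can be sketched in a few lines. This is a minimal illustration, assuming NumPy as a stand-in for PyTorch (`np.moveaxis` behaves like `torch.movedim`) and a random array in place of a real DataLoader batch; it shows that one image becomes 28 tokens of size 28, and that a batch can be laid out in `nn.LSTM`'s default `(seq_len, batch, input_size)` order without mixing up images.

```python
import numpy as np

# Hypothetical stand-in for one DataLoader batch: (batch, channel, height, width)
batch = np.random.rand(64, 1, 28, 28).astype(np.float32)

# One image as a sequence: row t is the t-th "token", a 28-dim feature vector
image = batch[0, 0]                      # (28, 28)
tokens = [image[t] for t in range(image.shape[0])]
print(len(tokens), tokens[0].shape)      # 28 (28,)

# Whole batch in nn.LSTM's default (seq_len, batch, input_size) layout:
# drop the channel axis, then move the batch axis to position 1
seq = np.moveaxis(batch.squeeze(axis=1), 0, 1)
print(seq.shape)                         # (28, 64, 28)

# seq[t, n] is exactly row t of image n
assert np.array_equal(seq[5, 3], batch[3, 0, 5])
```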

Environment:

  • OS: Ubuntu 20.04;
  • GPU: RTX 3090;
  • PyTorch: 1.11;
  • Python: 3.8

```python
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader

# Data preprocessing: ToTensor() turns each image into a (1, 28, 28) float tensor in [0, 1]
train_data = torchvision.datasets.FashionMNIST(root='data',
                                               train=True,
                                               transform=torchvision.transforms.ToTensor(),
                                               download=True)
test_data = torchvision.datasets.FashionMNIST(root='data',
                                              train=False,
                                              transform=torchvision.transforms.ToTensor(),
                                              download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=64, num_workers=4)


# Model building
class LSTM(nn.Module):
    def __init__(self):
        super().__init__()
        # input_size=28 (one image row per time step), hidden_size=64, 2 stacked layers
        self.lstm = nn.LSTM(28, 64, num_layers=2)
        self.linear = nn.Linear(64, 10)  # 10 clothing classes

    def forward(self, x):
        # x: (seq_len=28, batch, 28); passing None starts from zero hidden/cell states
        output, (h_n, c_n) = self.lstm(x, None)
        # h_n: (num_layers, batch, 64); h_n[-1] is the top layer's final hidden state.
        # (The original post used h_n[0], the first layer's state; the log below comes from that version.)
        return self.linear(h_n[-1])


def setup_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Setup
setup_seed(42)
NUM_EPOCHS = 20
LR = 4e-3
train_loss, test_loss, test_acc = [], [], []
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lstm = LSTM()
lstm.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm.parameters(), lr=LR)

# Training and testing
for epoch in range(NUM_EPOCHS):
    print(f'[Epoch {epoch + 1}]', end=' ')
    avg_train_loss, avg_test_loss, correct = 0, 0, 0

    # train
    lstm.train()
    for batch_idx, (X, y) in enumerate(train_loader):
        # (64, 1, 28, 28) -> (64, 28, 28) -> (28, 64, 28): image rows become time steps
        X = X.squeeze(1).movedim(0, 1)
        X, y = X.to(device), y.to(device)
        # forward
        output = lstm(X)
        loss = criterion(output, y)
        avg_train_loss += loss.item()  # .item() so the autograd graph is not kept alive
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    avg_train_loss /= (batch_idx + 1)
    train_loss.append(avg_train_loss)

    # test
    lstm.eval()
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(test_loader):
            X = X.squeeze(1).movedim(0, 1)
            X, y = X.to(device), y.to(device)
            pred = lstm(X)
            loss = criterion(pred, y)
            avg_test_loss += loss.item()
            correct += (pred.argmax(1) == y).sum().item()
    avg_test_loss /= (batch_idx + 1)
    test_loss.append(avg_test_loss)
    correct /= len(test_loader.dataset)
    test_acc.append(correct)
    print(f"train loss: {train_loss[-1]:.4f} | test loss: {test_loss[-1]:.4f} | test acc: {correct:.4f}")

# Plot
x = np.arange(1, NUM_EPOCHS + 1)
plt.plot(x, train_loss, label="train loss")
plt.plot(x, test_loss, label="test loss")
plt.plot(x, test_acc, label="test acc")
plt.xlabel("epoch")
plt.legend(loc="best", fontsize=12)
plt.show()
```
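To see what `nn.LSTM` hands back to the classifier, here is a hedged NumPy re-implementation of the standard LSTM recurrence. It is a sketch, not PyTorch's code: it assumes the gate order i, f, g, o given in the PyTorch `nn.LSTM` docs, folds PyTorch's two bias vectors into a single `b`, and uses random weights and a hypothetical batch of 5. It shows that `h_n` stacks one final hidden state per layer, so `h_n[0]` belongs to the first layer while `h_n[-1]` (identical to `output[-1]`) belongs to the top layer.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_layer(x, W_ih, W_hh, b):
    """Run one LSTM layer over a time-major sequence.

    x: (seq_len, batch, input_size); W_ih: (4*hidden, input_size);
    W_hh: (4*hidden, hidden); b: (4*hidden,). Gates stacked as i, f, g, o.
    """
    seq_len, batch, _ = x.shape
    hidden = W_hh.shape[1]
    h = np.zeros((batch, hidden))   # zero initial state, like passing None
    c = np.zeros((batch, hidden))
    outputs = []
    for t in range(seq_len):
        gates = x[t] @ W_ih.T + h @ W_hh.T + b            # (batch, 4*hidden)
        i, f, g, o = np.split(gates, 4, axis=1)
        i, f, g, o = sigmoid(i), sigmoid(f), np.tanh(g), sigmoid(o)
        c = f * c + i * g                                  # new cell state
        h = o * np.tanh(c)                                 # new hidden state
        outputs.append(h)
    return np.stack(outputs), h, c

def stacked_lstm(x, params):
    """params: one (W_ih, W_hh, b) tuple per layer; returns output, h_n, c_n."""
    h_n, c_n = [], []
    for W_ih, W_hh, b in params:
        x, h, c = lstm_layer(x, W_ih, W_hh, b)  # next layer consumes these outputs
        h_n.append(h)
        c_n.append(c)
    return x, np.stack(h_n), np.stack(c_n)

rng = np.random.default_rng(0)
input_size, hidden, num_layers = 28, 64, 2
params = []
for layer in range(num_layers):
    in_dim = input_size if layer == 0 else hidden
    params.append((rng.normal(0, 0.1, (4 * hidden, in_dim)),
                   rng.normal(0, 0.1, (4 * hidden, hidden)),
                   np.zeros(4 * hidden)))

x = rng.random((28, 5, input_size))       # (seq_len=28, batch=5, input_size=28)
output, h_n, c_n = stacked_lstm(x, params)
print(output.shape, h_n.shape)            # (28, 5, 64) (2, 5, 64)
```

Because the second layer only feeds `output` and `h_n[1]`, a classifier reading `h_n[0]` would ignore it entirely.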

Output:

```
[Epoch 1] train loss: 0.6602 | test loss: 0.5017 | test acc: 0.8147
[Epoch 2] train loss: 0.4089 | test loss: 0.3979 | test acc: 0.8566
[Epoch 3] train loss: 0.3577 | test loss: 0.3675 | test acc: 0.8669
[Epoch 4] train loss: 0.3268 | test loss: 0.3509 | test acc: 0.8751
[Epoch 5] train loss: 0.3098 | test loss: 0.3395 | test acc: 0.8752
[Epoch 6] train loss: 0.2962 | test loss: 0.3135 | test acc: 0.8854
[Epoch 7] train loss: 0.2823 | test loss: 0.3377 | test acc: 0.8776
[Epoch 8] train loss: 0.2720 | test loss: 0.3196 | test acc: 0.8835
[Epoch 9] train loss: 0.2623 | test loss: 0.3120 | test acc: 0.8849
[Epoch 10] train loss: 0.2547 | test loss: 0.2981 | test acc: 0.8931
[Epoch 11] train loss: 0.2438 | test loss: 0.3140 | test acc: 0.8882
[Epoch 12] train loss: 0.2372 | test loss: 0.3043 | test acc: 0.8909
[Epoch 13] train loss: 0.2307 | test loss: 0.2977 | test acc: 0.8918
[Epoch 14] train loss: 0.2219 | test loss: 0.2888 | test acc: 0.8970
[Epoch 15] train loss: 0.2187 | test loss: 0.2946 | test acc: 0.8959
[Epoch 16] train loss: 0.2132 | test loss: 0.2894 | test acc: 0.8985
[Epoch 17] train loss: 0.2061 | test loss: 0.2835 | test acc: 0.9014
[Epoch 18] train loss: 0.2028 | test loss: 0.2954 | test acc: 0.8971
[Epoch 19] train loss: 0.1966 | test loss: 0.2952 | test acc: 0.8986
[Epoch 20] train loss: 0.1922 | test loss: 0.2910 | test acc: 0.9011
```
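In the later epochs the train loss keeps falling (0.2132 to 0.1922 over epochs 16 to 20) while the test loss hovers around 0.29, so the gap between them widens. A small helper can make such observations from a saved log. The sketch below assumes the log text follows the `print` format used in the training loop; the regex and variable names are illustrative, and only the last five epochs are copied in.

```python
import re

# Last five epochs copied from the log above
log = """\
[Epoch 16] train loss: 0.2132 | test loss: 0.2894 | test acc: 0.8985
[Epoch 17] train loss: 0.2061 | test loss: 0.2835 | test acc: 0.9014
[Epoch 18] train loss: 0.2028 | test loss: 0.2954 | test acc: 0.8971
[Epoch 19] train loss: 0.1966 | test loss: 0.2952 | test acc: 0.8986
[Epoch 20] train loss: 0.1922 | test loss: 0.2910 | test acc: 0.9011
"""

# Tolerates optional spaces around ':' and '|'
pattern = re.compile(
    r"\[Epoch (\d+)\]\s*train loss:\s*([\d.]+)\s*\|\s*"
    r"test loss:\s*([\d.]+)\s*\|\s*test acc:\s*([\d.]+)"
)
records = [(int(e), float(tr), float(te), float(acc))
           for e, tr, te, acc in pattern.findall(log)]

best = max(records, key=lambda r: r[3])     # epoch with the highest test accuracy
print(f"best epoch: {best[0]}, test acc: {best[3]:.4f}")   # best epoch: 17, test acc: 0.9014

# Generalization gap per epoch: test loss minus train loss
gap = [round(te - tr, 4) for _, tr, te, _ in records]
print(gap)
```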

Corresponding curves:

[Figure: train loss, test loss and test accuracy per epoch]


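Some lessons learned:

  • Do not simply use `X = X.reshape(28, -1, 28)`. Reshape only regroups the elements in their existing row-major order instead of swapping the batch and row axes, so X would no longer correspond to the original images (you can check this by converting X back with `torchvision.transforms.ToPILImage` and viewing the result).
  • With the same learning rate, SGD performed worse than Adam in this experiment.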
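The reshape pitfall can be reproduced without loading any images. In this sketch NumPy stands in for PyTorch (`np.moveaxis` plays the role of `torch.movedim`, and NumPy's `reshape` follows the same row-major rule as PyTorch's), and the integer-valued batch is a made-up stand-in. Both calls produce a `(28, 64, 28)` array, but only the axis move keeps each batch slot tied to a single image.

```python
import numpy as np

# Hypothetical batch after squeezing the channel axis: (batch, height, width)
batch_size = 64
X = np.arange(batch_size * 28 * 28).reshape(batch_size, 28, 28)

good = np.moveaxis(X, 0, 1)      # (28, 64, 28): good[t, n] is row t of image n
bad = X.reshape(28, -1, 28)      # also (28, 64, 28), but only regroups memory

print(good.shape == bad.shape)   # True
assert np.array_equal(good[1, 0], X[0, 1])   # row 1 of image 0, as intended

# reshape walks the flat list of 64*28 rows in order, so bad[t, n] is flat row
# t*64 + n, i.e. image (t*64 + n) // 28, row (t*64 + n) % 28
t, n = 1, 0
r = t * batch_size + n
assert np.array_equal(bad[t, n], X[r // 28, r % 28])   # row 8 of image 2
print(np.array_equal(good, bad))  # False
```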

Reposted from: https://blog.csdn.net/raelum/article/details/125432230
Copyright belongs to the original author, raelum.
