

3 Ways to Get a Model Summary in PyTorch

How to get a model's trainable and non-trainable parameters, layer names, and kernel sizes and counts in PyTorch.

PyTorch's nn.Module class does not provide a Keras-style method that counts a model's trainable and non-trainable parameters and prints a model summary. So in this post I will summarize the three methods I know for counting the trainable and non-trainable parameters of a PyTorch model.

Writing the code yourself

The most direct approach is to implement the feature ourselves, so here is a function I wrote by hand; it uses the PrettyTable package to make the output look nice:

    from prettytable import PrettyTable

    def count_parameters(model):
        table = PrettyTable(["Modules", "Parameters"])
        total_params = 0
        for name, parameter in model.named_parameters():
            # Skip parameters that are frozen (not trainable)
            if not parameter.requires_grad:
                continue
            params = parameter.numel()
            table.add_row([name, params])
            total_params += params
        print(table)
        print(f"Total Trainable Params: {total_params}")
        return total_params
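
As a minimal usage sketch (assuming torchvision is available), the function can be applied to a stock ResNet18 like this:

    import torchvision

    model = torchvision.models.resnet18()
    count_parameters(model)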

Taking ResNet18 as an example, the output of the function above is as follows:

    +------------------------------+------------+
    |           Modules            | Parameters |
    +------------------------------+------------+
    |         conv1.weight         |    9408    |
    |          bn1.weight          |     64     |
    |           bn1.bias           |     64     |
    |    layer1.0.conv1.weight     |   36864    |
    |     layer1.0.bn1.weight      |     64     |
    |      layer1.0.bn1.bias       |     64     |
    .
    .
    .
    |          fc.weight           |   512000   |
    |           fc.bias            |    1000    |
    +------------------------------+------------+
    Total Trainable Params: 11689512

The output lists, parameter tensor by parameter tensor, the trainable parameter count of every parameter in the model, which is basically the same information Keras gives you.
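
If you only need the total rather than the per-layer table, the same count can be written as a common one-liner:

    # Total number of trainable parameters, no table
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(total)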

torchsummary

torchsummary was created precisely so that PyTorch would have a Keras-like way of printing model summaries. It is friendly and very simple. The current version is 1.5.1, and it can be installed directly with pip:

    pip install torchsummary

Once installed it is ready to use; again we take resnet18 as the example:

    import torchvision
    from torchsummary import summary

    model = torchvision.models.resnet18().cuda()

When using it, we need to supply the model's input shape so that summary can simulate the model's forward pass:

    summary(model, input_size=(3, 224, 224), batch_size=-1)

The result is as follows:

    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1         [-1, 64, 112, 112]           9,408
           BatchNorm2d-2         [-1, 64, 112, 112]             128
                  ReLU-3         [-1, 64, 112, 112]               0
             MaxPool2d-4           [-1, 64, 56, 56]               0
                Conv2d-5           [-1, 64, 56, 56]          36,864
    .
    .
    .
     AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
               Linear-68                 [-1, 1000]         513,000
    ================================================================
    Total params: 11,689,512
    Trainable params: 11,689,512
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.57
    Forward/backward pass size (MB): 62.79
    Params size (MB): 44.59
    Estimated Total Size (MB): 107.96
    ----------------------------------------------------------------
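
One caveat: summary runs the model on CUDA by default, which is why the model was moved to the GPU above. On a CPU-only machine you can pass the device argument instead; a minimal sketch, assuming torchsummary 1.5.1's device parameter:

    # Run the summary on the CPU instead of the default CUDA device
    model = torchvision.models.resnet18()
    summary(model, input_size=(3, 224, 224), batch_size=-1, device="cpu")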

Now, if your base model has multiple branches, each taking a different input, for example:

    import torch
    import torchvision

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.resnet1 = torchvision.models.resnet18().cuda()
            self.resnet2 = torchvision.models.resnet18().cuda()
            self.resnet3 = torchvision.models.resnet18().cuda()

        def forward(self, *x):
            out1 = self.resnet1(x[0])
            out2 = self.resnet2(x[1])
            out3 = self.resnet3(x[2])
            # Concatenate the three branch outputs along the batch dimension
            out = torch.cat([out1, out2, out3], dim=0)
            return out

then you need to call it like this, with one input size per branch:

    summary(Model().cuda(), input_size=[(3, 64, 64)] * 3)

The output will be similar to the previous one, but a bit confusing, because torchsummary squeezes the information of every constituent ResNet module into one flat summary, with no clear boundary between two consecutive modules.

torchinfo

It may look similar to torchsummary, but in my opinion it is the best of the three methods I found. The current version of torchinfo is 1.7.0, and it too can be installed with pip:

    pip install torchinfo

This package also has a function named summary, but it takes many more arguments: model (nn.Module), input_size (Sequence of Sizes), input_data (Sequence of Tensors), batch_dim (int), cache_forward_pass (bool), col_names (Iterable[str]), col_width (int), depth (int), device (torch.Device), dtypes (List[torch.dtype]), mode (str), row_settings (Iterable[str]), verbose (int), and **kwargs.

That is a lot of parameters, but the key one is col_names: you can pass any of ("input_size", "output_size", "num_params", "kernel_size", "mult_adds", "trainable") to choose which columns of information appear.

    import torchinfo

    torchinfo.summary(
        model, (3, 224, 224), batch_dim=0,
        col_names=("input_size", "output_size", "num_params", "kernel_size", "mult_adds"),
        verbose=0,
    )

Note that if you are not running in Jupyter or Google Colab, you need to change verbose to 1.

The output of the snippet above looks like this:

    =============================================================================================
    Layer (type:depth-idx)                   Input Shape         Output Shape        Param #      Kernel Shape   Mult-Adds
    =============================================================================================
    ResNet                                   [1, 3, 224, 224]    [1, 1000]           --           --             --
    ├─Conv2d: 1-1                            [1, 3, 224, 224]    [1, 64, 112, 112]   9,408        [7, 7]         118,013,952
    ├─BatchNorm2d: 1-2                       [1, 64, 112, 112]   [1, 64, 112, 112]   128          --             128
    ├─ReLU: 1-3                              [1, 64, 112, 112]   [1, 64, 112, 112]   --           --             --
    ├─MaxPool2d: 1-4                         [1, 64, 112, 112]   [1, 64, 56, 56]     --           3              --
    ├─Sequential: 1-5                        [1, 64, 56, 56]     [1, 64, 56, 56]     --           --             --
    │    └─BasicBlock: 2-1                   [1, 64, 56, 56]     [1, 64, 56, 56]     --           --             --
    │    │    └─Conv2d: 3-1                  [1, 64, 56, 56]     [1, 64, 56, 56]     36,864       [3, 3]         115,605,504
    │    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]     [1, 64, 56, 56]     128          --             128
    │    │    └─ReLU: 3-3                    [1, 64, 56, 56]     [1, 64, 56, 56]     --           --             --
    │    │    └─Conv2d: 3-4                  [1, 64, 56, 56]     [1, 64, 56, 56]     36,864       [3, 3]         115,605,504
    │    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]     [1, 64, 56, 56]     128          --             128
    │    │    └─ReLU: 3-6                    [1, 64, 56, 56]     [1, 64, 56, 56]     --           --             --
    │    └─BasicBlock: 2-2                   [1, 64, 56, 56]     [1, 64, 56, 56]     --           --             --
    │    │    └─Conv2d: 3-7                  [1, 64, 56, 56]     [1, 64, 56, 56]     36,864       [3, 3]         115,605,504
    │    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]     [1, 64, 56, 56]     128          --             128
    │    │    └─ReLU: 3-9                    [1, 64, 56, 56]     [1, 64, 56, 56]     --           --             --
    │    │    └─Conv2d: 3-10                 [1, 64, 56, 56]     [1, 64, 56, 56]     36,864       [3, 3]         115,605,504
    │    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]     [1, 64, 56, 56]     128          --             128
    │    │    └─ReLU: 3-12                   [1, 64, 56, 56]     [1, 64, 56, 56]     --           --             --
    ├─Sequential: 1-6                        [1, 64, 56, 56]     [1, 128, 28, 28]    --           --             --
    │    └─BasicBlock: 2-3                   [1, 64, 56, 56]     [1, 128, 28, 28]    --           --             --
    │    │    └─Conv2d: 3-13                 [1, 64, 56, 56]     [1, 128, 28, 28]    73,728       [3, 3]         57,802,752
    │    │    └─BatchNorm2d: 3-14            [1, 128, 28, 28]    [1, 128, 28, 28]    256          --             256
    .
    .
    .
    │    │    └─Conv2d: 3-49                 [1, 512, 7, 7]      [1, 512, 7, 7]      2,359,296    [3, 3]         115,605,504
    │    │    └─BatchNorm2d: 3-50            [1, 512, 7, 7]      [1, 512, 7, 7]      1,024        --             1,024
    │    │    └─ReLU: 3-51                   [1, 512, 7, 7]      [1, 512, 7, 7]      --           --             --
    ├─AdaptiveAvgPool2d: 1-9                 [1, 512, 7, 7]      [1, 512, 1, 1]      --           --             --
    ├─Linear: 1-10                           [1, 512]            [1, 1000]           513,000      --             513,000
    =============================================================================================
    Total params: 11,689,512
    Trainable params: 11,689,512
    Non-trainable params: 0
    Total mult-adds (G): 1.81
    =============================================================================================
    Input size (MB): 0.60
    Forward/backward pass size (MB): 39.75
    Params size (MB): 46.76
    Estimated Total Size (MB): 87.11
    =============================================================================================
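
As an aside, the input_data parameter from the argument list above can be used instead of a shape: you pass real tensors and torchinfo traces the forward pass with them. A minimal sketch:

    import torch

    model_cpu = torchvision.models.resnet18()
    # Pass an actual tensor instead of an input_size shape; torchinfo
    # runs a forward pass with it to record shapes and mult-adds.
    torchinfo.summary(model_cpu, input_data=torch.randn(1, 3, 224, 224), verbose=1)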

Now look again at the multi-branch model:

    torchinfo.summary(
        Model().cuda(), [(3, 64, 64)] * 3, batch_dim=0,
        col_names=("input_size", "output_size", "num_params", "kernel_size", "mult_adds"),
        verbose=0,
    )

which produces the following output:

    =============================================================================================
    Layer (type:depth-idx)                   Input Shape         Output Shape        Param #      Kernel Shape   Mult-Adds
    =============================================================================================
    Model                                    [1, 3, 64, 64]      [1, 1000]           --           --             --
    ├─ResNet: 1-1                            [1, 3, 64, 64]      [1, 1000]           --           --             --
    │    └─Conv2d: 2-1                       [1, 3, 64, 64]      [1, 64, 32, 32]     9,408        [7, 7]         9,633,792
    │    └─BatchNorm2d: 2-2                  [1, 64, 32, 32]     [1, 64, 32, 32]     128          --             128
    │    └─ReLU: 2-3                         [1, 64, 32, 32]     [1, 64, 32, 32]     --           --             --
    │    └─MaxPool2d: 2-4                    [1, 64, 32, 32]     [1, 64, 16, 16]     --           3              --
    │    └─Sequential: 2-5                   [1, 64, 16, 16]     [1, 64, 16, 16]     --           --             --
    │    │    └─BasicBlock: 3-1              [1, 64, 16, 16]     [1, 64, 16, 16]     73,984       --             18,874,624
    │    │    └─BasicBlock: 3-2              [1, 64, 16, 16]     [1, 64, 16, 16]     73,984       --             18,874,624
    │    └─Sequential: 2-6                   [1, 64, 16, 16]     [1, 128, 8, 8]      --           --             --
    │    │    └─BasicBlock: 3-3              [1, 64, 16, 16]     [1, 128, 8, 8]      230,144      --             14,680,832
    │    │    └─BasicBlock: 3-4              [1, 128, 8, 8]      [1, 128, 8, 8]      295,424      --             18,874,880
    │    └─Sequential: 2-7                   [1, 128, 8, 8]      [1, 256, 4, 4]      --           --             --
    │    │    └─BasicBlock: 3-5              [1, 128, 8, 8]      [1, 256, 4, 4]      919,040      --             14,681,600
    │    │    └─BasicBlock: 3-6              [1, 256, 4, 4]      [1, 256, 4, 4]      1,180,672    --             18,875,392
    │    └─Sequential: 2-8                   [1, 256, 4, 4]      [1, 512, 2, 2]      --           --             --
    │    │    └─BasicBlock: 3-7              [1, 256, 4, 4]      [1, 512, 2, 2]      3,673,088    --             14,683,136
    │    │    └─BasicBlock: 3-8              [1, 512, 2, 2]      [1, 512, 2, 2]      4,720,640    --             18,876,416
    │    └─AdaptiveAvgPool2d: 2-9            [1, 512, 2, 2]      [1, 512, 1, 1]      --           --             --
    │    └─Linear: 2-10                      [1, 512]            [1, 1000]           513,000      --             513,000
    ├─ResNet: 1-2                            [1, 3, 64, 64]      [1, 1000]           --           --             --
    │    └─Conv2d: 2-11                      [1, 3, 64, 64]      [1, 64, 32, 32]     9,408        [7, 7]         9,633,792
    │    └─BatchNorm2d: 2-12                 [1, 64, 32, 32]     [1, 64, 32, 32]     128          --             128
    │    └─ReLU: 2-13                        [1, 64, 32, 32]     [1, 64, 32, 32]     --           --             --
    │    └─MaxPool2d: 2-14                   [1, 64, 32, 32]     [1, 64, 16, 16]     --           3              --
    │    └─Sequential: 2-15                  [1, 64, 16, 16]     [1, 64, 16, 16]     --           --             --
    │    │    └─BasicBlock: 3-9              [1, 64, 16, 16]     [1, 64, 16, 16]     73,984       --             18,874,624
    │    │    └─BasicBlock: 3-10             [1, 64, 16, 16]     [1, 64, 16, 16]     73,984       --             18,874,624
    │    └─Sequential: 2-16                  [1, 64, 16, 16]     [1, 128, 8, 8]      --           --             --
    │    │    └─BasicBlock: 3-11             [1, 64, 16, 16]     [1, 128, 8, 8]      230,144      --             14,680,832
    │    │    └─BasicBlock: 3-12             [1, 128, 8, 8]      [1, 128, 8, 8]      295,424      --             18,874,880
    │    └─Sequential: 2-17                  [1, 128, 8, 8]      [1, 256, 4, 4]      --           --             --
    │    │    └─BasicBlock: 3-13             [1, 128, 8, 8]      [1, 256, 4, 4]      919,040      --             14,681,600
    │    │    └─BasicBlock: 3-14             [1, 256, 4, 4]      [1, 256, 4, 4]      1,180,672    --             18,875,392
    │    └─Sequential: 2-18                  [1, 256, 4, 4]      [1, 512, 2, 2]      --           --             --
    │    │    └─BasicBlock: 3-15             [1, 256, 4, 4]      [1, 512, 2, 2]      3,673,088    --             14,683,136
    │    │    └─BasicBlock: 3-16             [1, 512, 2, 2]      [1, 512, 2, 2]      4,720,640    --             18,876,416
    │    └─AdaptiveAvgPool2d: 2-19           [1, 512, 2, 2]      [1, 512, 1, 1]      --           --             --
    │    └─Linear: 2-20                      [1, 512]            [1, 1000]           513,000      --             513,000
    ├─ResNet: 1-3                            [1, 3, 64, 64]      [1, 1000]           --           --             --
    │    └─Conv2d: 2-21                      [1, 3, 64, 64]      [1, 64, 32, 32]     9,408        [7, 7]         9,633,792
    │    └─BatchNorm2d: 2-22                 [1, 64, 32, 32]     [1, 64, 32, 32]     128          --             128
    │    └─ReLU: 2-23                        [1, 64, 32, 32]     [1, 64, 32, 32]     --           --             --
    │    └─MaxPool2d: 2-24                   [1, 64, 32, 32]     [1, 64, 16, 16]     --           3              --
    │    └─Sequential: 2-25                  [1, 64, 16, 16]     [1, 64, 16, 16]     --           --             --
    │    │    └─BasicBlock: 3-17             [1, 64, 16, 16]     [1, 64, 16, 16]     73,984       --             18,874,624
    │    │    └─BasicBlock: 3-18             [1, 64, 16, 16]     [1, 64, 16, 16]     73,984       --             18,874,624
    │    └─Sequential: 2-26                  [1, 64, 16, 16]     [1, 128, 8, 8]      --           --             --
    │    │    └─BasicBlock: 3-19             [1, 64, 16, 16]     [1, 128, 8, 8]      230,144      --             14,680,832
    │    │    └─BasicBlock: 3-20             [1, 128, 8, 8]      [1, 128, 8, 8]      295,424      --             18,874,880
    │    └─Sequential: 2-27                  [1, 128, 8, 8]      [1, 256, 4, 4]      --           --             --
    │    │    └─BasicBlock: 3-21             [1, 128, 8, 8]      [1, 256, 4, 4]      919,040      --             14,681,600
    │    │    └─BasicBlock: 3-22             [1, 256, 4, 4]      [1, 256, 4, 4]      1,180,672    --             18,875,392
    │    └─Sequential: 2-28                  [1, 256, 4, 4]      [1, 512, 2, 2]      --           --             --
    │    │    └─BasicBlock: 3-23             [1, 256, 4, 4]      [1, 512, 2, 2]      3,673,088    --             14,683,136
    │    │    └─BasicBlock: 3-24             [1, 512, 2, 2]      [1, 512, 2, 2]      4,720,640    --             18,876,416
    │    └─AdaptiveAvgPool2d: 2-29           [1, 512, 2, 2]      [1, 512, 1, 1]      --           --             --
    │    └─Linear: 2-30                      [1, 512]            [1, 1000]           513,000      --             513,000
    =============================================================================================
    Total params: 35,068,536
    Trainable params: 35,068,536
    Non-trainable params: 0
    Total mult-adds (M): 445.71
    =============================================================================================
    Input size (MB): 0.15
    Forward/backward pass size (MB): 9.76
    Params size (MB): 140.27
    Estimated Total Size (MB): 150.18
    =============================================================================================

You can see that the depth parameter defaults to 3. And in terms of presentation, the branches are reorganized and shown hierarchically, so they are easy to tell apart. This makes torchinfo work much better than torchsummary.
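
Since depth defaults to 3, lowering it collapses the tree; for example, depth=1 should keep only each branch's top-level modules. A small sketch:

    # Collapse the summary to the top-level children of each branch
    torchinfo.summary(Model().cuda(), [(3, 64, 64)] * 3, batch_dim=0, depth=1, verbose=0)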

Author: Siladittya Manna
