文章目录
前言
今天在这里纪录一下如何对torch网络的层进行更改:变更,增加,删除与查找
这里拿VGG16网络举例,先看一下网络结构
import torch
import torch.nn as nn
from torchvision import models
net = models.vgg11(pretrained=True)
一、在网络中添加一层:
net网络是一个树型结构, net下面有三个结点,分别是(features, avgpoll, classifier), 我们先在features结点添加一层’lastlayer’层
net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))
- 在classifier结点添加一个线性层:
net.classifier.add_module('Linear', nn.Linear(1000,10))
二、修改网络中的某一层
- 以features 结点举例
net.features[8]= nn.Conv2d(512,512, kernel_size=(3,3), stride=(1,1), padding=(1,1))
- 以classifier结点举例
net.classifier[6]= nn.Linear(1000,5)
注意: 这里我尝试对Linear这一层进行更新, 但是Linear名字是字符串, 提取不出来,所以应该在之前添加网络时候, 名字不要取字符串, 否则会报错 ‘ 'str' object cannot be interpreted as an integer’。
三、网络层的删除
方法一:使用关键字del删除层(推荐)
删除前
model = prepare_vitmodel('mae_visualize_vit_large_ganloss.pth','vit_large_patch16')del model.head # 删除层
model
删除后
方法二:将层设置为空层
以features举例 classifier结点的操作相同,这里直接使用nn.Sequential()对改层设置为空即可
net.features[13]= nn.Sequential()
四、网络层的切片
net.features = nn.Sequential(*list(net.features.children())[:-4])
可以看到后面4层被去除了, 就是说可以使用列表切片的方法来删除网络层
net.classifier 对应 net.classifier.children()
net.features 对应 net.features.children()
五、网络层的冻结
#冻结指定层的预训练参数:
net.feature[26].weight.requires_grad =False
版权归原作者 夜雨窗中人 所有, 如有侵权,请联系我们删除。