0


pytorch对网络层的增加,删除,变更和切片

文章目录


前言

今天在这里纪录一下如何对torch网络的层进行更改:变更,增加,删除与查找
这里拿VGG16网络举例,先看一下网络结构

import torch
import torch.nn as nn
from torchvision import models

net = models.vgg11(pretrained=True)

在这里插入图片描述

一、在网络中添加一层:

net网络是一个树型结构, net下面有三个结点,分别是(features, avgpoll, classifier), 我们先在features结点添加一层’lastlayer’层

net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))

在这里插入图片描述

  • 在classifier结点添加一个线性层:
net.classifier.add_module('Linear', nn.Linear(1000,10))

在这里插入图片描述

二、修改网络中的某一层

  • 以features 结点举例
net.features[8]= nn.Conv2d(512,512, kernel_size=(3,3), stride=(1,1), padding=(1,1))

在这里插入图片描述

  • 以classifier结点举例
net.classifier[6]= nn.Linear(1000,5)
注意: 这里我尝试对Linear这一层进行更新, 但是Linear名字是字符串, 提取不出来,所以应该在之前添加网络时候, 名字不要取字符串, 否则会报错 ‘ 'str' object cannot be interpreted as an integer’。 

在这里插入图片描述

三、网络层的删除

方法一:使用关键字del删除层(推荐)

删除前
在这里插入图片描述

model = prepare_vitmodel('mae_visualize_vit_large_ganloss.pth','vit_large_patch16')del model.head  # 删除层
model

删除后
在这里插入图片描述

方法二:将层设置为空层

以features举例 classifier结点的操作相同,这里直接使用nn.Sequential()对改层设置为空即可

net.features[13]= nn.Sequential()

在这里插入图片描述

四、网络层的切片

net.features = nn.Sequential(*list(net.features.children())[:-4])

可以看到后面4层被去除了, 就是说可以使用列表切片的方法来删除网络层
net.classifier 对应 net.classifier.children()
net.features 对应 net.features.children()
在这里插入图片描述

五、网络层的冻结

#冻结指定层的预训练参数:
net.feature[26].weight.requires_grad =False

本文转载自: https://blog.csdn.net/Gw2092330995/article/details/129738155
版权归原作者 夜雨窗中人 所有, 如有侵权,请联系我们删除。

“pytorch对网络层的增加,删除,变更和切片”的评论:

还没有评论