0


pytorch中nn.ModuleList()使用方法

定义ModuleList

我们可以将我们需要的层放入到一个集合中,然后将这个集合作为参数传入nn.ModuleList中,但是这个子类并不可以直接使用,因为这个子类并没有实现forward函数,所以要使用还需要放在继承了nn.Module的模型中进行使用。

model_list = nn.ModuleList([nn.Conv2d(1,5,2), nn.Linear(10,2), nn.Sigmoid()])

x = torch.randn(32,3,24,24)for model in model_list:
    model_list(x)

使用ModuleList定义网络

classNet(nn.Module):def__init__(self):super().__init__()
        self.model_list = nn.ModuleList([nn.Conv2d(1,5,2), nn.Linear(10,2), nn.Sigmoid()])defforward(self, x):return self.model_list(x)

打印网络层结构

model = Net()print(model)
Net((model_list): ModuleList((0): Conv2d(1,5, kernel_size=(2,2), stride=(1,1))(1): Linear(in_features=10, out_features=2, bias=True)(2): Sigmoid()))

本文转载自: https://blog.csdn.net/m0_47256162/article/details/127824476
版权归原作者 海洋 之心 所有, 如有侵权,请联系我们删除。

“pytorch中nn.ModuleList()使用方法”的评论:

还没有评论