定义ModuleList
我们可以将我们需要的层放入到一个集合中,然后将这个集合作为参数传入nn.ModuleList中,但是这个子类并不可以直接使用,因为这个子类并没有实现forward函数,所以要使用还需要放在继承了nn.Module的模型中进行使用。
model_list = nn.ModuleList([nn.Conv2d(1,5,2), nn.Linear(10,2), nn.Sigmoid()])
x = torch.randn(32,3,24,24)for model in model_list:
model_list(x)
使用ModuleList定义网络
classNet(nn.Module):def__init__(self):super().__init__()
self.model_list = nn.ModuleList([nn.Conv2d(1,5,2), nn.Linear(10,2), nn.Sigmoid()])defforward(self, x):return self.model_list(x)
打印网络层结构
model = Net()print(model)
Net((model_list): ModuleList((0): Conv2d(1,5, kernel_size=(2,2), stride=(1,1))(1): Linear(in_features=10, out_features=2, bias=True)(2): Sigmoid()))
本文转载自: https://blog.csdn.net/m0_47256162/article/details/127824476
版权归原作者 海洋 之心 所有, 如有侵权,请联系我们删除。
版权归原作者 海洋 之心 所有, 如有侵权,请联系我们删除。