通过 nn.ModuleDict 对象创建 nn.Module
时间: 2023-10-23 12:27:01 浏览: 49
可以通过 nn.ModuleDict 对象来创建一个包含多个子模块的 nn.Module 对象。具体来说,可以使用 nn.ModuleDict 的构造函数创建一个空的字典,然后通过字典的 update 方法逐个添加子模块。每个子模块都需要指定一个名称,这个名称将作为字典中的键。例如,下面的代码创建了一个包含两个子模块的 nn.Module 对象:
```python
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.submodules = nn.ModuleDict()
self.submodules.update({'conv': nn.Conv2d(3, 64, kernel_size=3, padding=1),
'relu': nn.ReLU()})
def forward(self, x):
x = self.submodules['conv'](x)
x = self.submodules['relu'](x)
return x
```
在这个例子中,我们创建了一个名为 MyModule 的 nn.Module 对象,它包含了一个字典 submodules,其中包含了两个子模块:一个卷积层 nn.Conv2d 和一个激活函数 nn.ReLU。在 forward 方法中,我们通过字典的键来逐个调用子模块,实现了一个简单的前向传播过程。