nn.Module属性
时间: 2023-07-23 09:14:26 浏览: 85
在 PyTorch 中,每个 nn.Module 对象都具有以下属性:
1. training: 一个布尔值,表示该模块当前是否处于训练模式。如果处于训练模式,则为 True,否则为 False。
2. _parameters: 一个 OrderedDict,包含该模块的可学习参数。每个参数都是一个 Tensor 对象,并且可以通过调用该模块的 parameters() 方法来获取。
3. _buffers: 一个 OrderedDict,包含该模块的所有缓存。缓存通常是状态变量,例如 BatchNorm 中的 running_mean 和 running_var。
4. _modules: 一个 OrderedDict,包含该模块的所有子模块。子模块可以是任何 nn.Module 对象,包括 nn.Sequential 和 nn.ModuleList。
5. _non_persistent_buffers_set: 一个 set,包含该模块所有不持久化的缓存的名称。这些缓存通常是临时变量,例如 BatchNorm 中的保存的中间变量。
6. _backward_hooks: 一个 OrderedDict,包含该模块的所有 backward hook。backward hook 是在反向传播过程中调用的函数,用于修改梯度或执行其他操作。
7. _forward_hooks: 一个 OrderedDict,包含该模块的所有 forward hook。forward hook 是在前向传播过程中调用的函数,用于修改输出或执行其他操作。
这些属性都是在 PyTorch 中使用 nn.Module 时非常重要的。通过使用这些属性,可以方便地查看模块的状态、获取和修改参数、添加子模块等等。
阅读全文