有一个torch.nn.module的子类,用for循环打印模型的结构、参数、权重
时间: 2024-10-17 08:14:59 浏览: 53
要实现这个功能,你可以创建一个自定义的`nn.Module`子类,并使用`__repr__`方法来展示模型的结构、参数和权重。`__repr__`是一个特殊的方法,当调用对象的`str()`或`repr()`时会被自动调用。
下面是一个简单的例子:
```python
class CustomModule(nn.Module):
def __init__(self):
super(CustomModule, self).__init__()
self.linear1 = nn.Linear(10, 20)
self.linear2 = nn.Linear(20, 30)
def forward(self, x):
return self.linear1(x).relu() + self.linear2(x).relu()
def extra_repr(self):
# 返回模型的描述信息
return f"Structure: {self._get_name()} \nParameters: {self.parameters().__len__()} \nWeights: {self.linear1.weight.shape}, {self.linear2.weight.shape}"
# 创建并初始化模型实例
model = CustomModule()
# 打印模型信息
for param in model.named_parameters():
print(f"{param[0]}: shape={param[1].shape}")
print(model)
```
当你运行这段代码时,它会遍历模块的所有参数,并打印每个参数的名字及其形状。`extra_repr`方法用于返回额外的模型描述信息,包括其结构、参数数量以及两个线性层的权重形状。
阅读全文