torch.summary
时间: 2023-07-29 10:14:27 浏览: 59
torch.summary是PyTorch中的一个函数,可以用来查看模型的结构和参数数量。它接受一个模型实例和输入数据的大小,并打印出模型的每一层的结构以及总共的参数数量。
以下是一个使用torch.summary函数的示例:
```python
import torch
import torch.nn as nn
from torchsummary import summary
# 定义一个简单的模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.fc = nn.Linear(64 * 32 * 32, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 创建一个模型实例
model = MyModel()
# 打印模型结构和参数数量
summary(model, (3, 32, 32))
```
运行以上代码,将会输出模型的结构和每一层的输出大小,以及总共的参数数量。这对于了解和调试模型非常有用。