如何使用torchstat统计模型参数量
时间: 2023-12-17 17:04:17 浏览: 172
使用torchstat可以很方便地统计模型参数量。具体步骤如下:
1. 安装torchstat:在命令行中输入 `pip install torchstat`
2. 导入需要统计参数量的模型和torchstat:
```python
import torch
import torch.nn as nn
from torchstat import stat
model = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(25088, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, 1000),
nn.ReLU()
)
```
3. 使用 `stat` 函数统计模型参数量:
```python
stat(model, (3, 224, 224))
```
其中 `(3, 224, 224)` 是模型接受的输入尺寸,可以根据需要进行修改。
运行之后会输出模型的参数量、计算量和内存占用等信息。
阅读全文