pytoch统计网络参数和flops的代码
时间: 2024-04-06 07:30:10 浏览: 36
统计网络参数和FLOPs是深度学习中常用的操作,可以通过以下代码来实现:
```python
import torch
from torchsummary import summary
# 定义模型
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = torch.nn.Linear(64*8*8, 128)
self.fc2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = torch.nn.functional.relu(self.conv1(x))
x = self.pool(x)
x = torch.nn.functional.relu(self.conv2(x))
x = self.pool(x)
x = x.view(-1, 64*8*8)
x = torch.nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
# 统计参数和FLOPs
model = Net()
input_data = torch.randn(1, 3, 32, 32)
summary(model, input_data)
# 输出结果
'''
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 32, 32] 896
MaxPool2d-2 [-1, 32, 16, 16] 0
Conv2d-3 [-1, 64, 16, 16] 18,496
MaxPool2d-4 [-1, 64, 8, 8] 0
Linear-5 [-1, 128] 524,416
Linear-6 [-1, 10] 1,290
================================================================
Total params: 545,098
Trainable params: 545,098
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.50
Params size (MB): 2.08
Estimated Total Size (MB): 2.59
----------------------------------------------------------------
FLOPs(2D): 111,411,200
FLOPs(3D): 18,564,864
----------------------------------------------------------------
'''
```
这里使用了 `torchsummary` 库来统计模型参数和FLOPs,在代码中 `Net` 类定义了一个简单的卷积神经网络,`summary` 函数可以输出模型各层的参数和FLOPs。其中, `FLOPs(2D)` 表示网络的二维FLOPs,即在输入数据为二维时的FLOPs, `FLOPs(3D)` 表示网络的三维FLOPs,即在输入数据为三维(如3D卷积神经网络)时的FLOPs。
阅读全文