pytorch中计算参数量的代码
时间: 2024-02-27 16:57:52 浏览: 87
下面是用PyTorch计算模型参数量的示例代码:
```python
import torch
import torch.nn as nn
# 定义一个模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(64 * 8 * 8, 512)
self.fc2 = nn.Linear(512, 10)
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = nn.functional.max_pool2d(x, 2)
x = nn.functional.relu(self.conv2(x))
x = nn.functional.max_pool2d(x, 2)
x = x.view(-1, 64 * 8 * 8)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
# 创建一个模型实例
model = MyModel()
# 统计模型参数量
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of parameters: {:,}'.format(num_params))
```
以上代码中,我们首先定义了一个简单的卷积神经网络模型`MyModel`,然后创建了一个模型实例。最后,我们使用`sum()`函数和`numel()`方法统计了模型中所有需要训练的参数的数量,并打印出了结果。注意,在计算参数量时,我们只统计了需要训练的参数,因此不包括偏置项等不需要训练的参数。
阅读全文