Code for computing FLOPs and Params
Date: 2024-01-04 07:04:18
FLOPs and Params are two common measures of a deep neural network's complexity. FLOPs (Floating Point Operations) counts the floating-point operations performed in one forward pass, so it depends on the input size; Params is the number of trainable parameters in the model and is independent of the input.
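Before profiling a whole model, the per-layer formulas are easy to check by hand. The sketch below uses hypothetical sizes (a 3x3 convolution from 3 to 32 channels applied to a 16x16 feature map) and the usual convention that one multiply-accumulate counts as two FLOPs:

```python
# Hypothetical layer sizes: 3x3 kernel, 3 -> 32 channels, 16x16 output map
k, c_in, c_out, h_out, w_out = 3, 3, 32, 16, 16

params = k * k * c_in * c_out + c_out        # weights + biases
macs = k * k * c_in * c_out * h_out * w_out  # multiply-accumulates per forward pass
flops = 2 * macs                             # 1 MAC = 1 multiply + 1 add

print(params)  # 896
print(flops)   # 442368
```

The same pattern (weight count plus biases for Params, operations per output element times output size for FLOPs) generalizes to any convolutional or fully connected layer.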
Below is a Python example that computes FLOPs and Params for a network with three convolutional layers and two fully connected layers:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

model = Net()

# FLOPs depend on the feature-map sizes, so count them with forward hooks
flops = 0

def conv_hook(module, inputs, output):
    global flops
    # one k_h*k_w*(C_in/groups) multiply-accumulate per output element; 1 MAC = 2 FLOPs
    k_h, k_w = module.kernel_size
    flops += 2 * output.numel() * k_h * k_w * module.in_channels // module.groups

def linear_hook(module, inputs, output):
    global flops
    flops += 2 * module.in_features * module.out_features * output.size(0)

hooks = [m.register_forward_hook(conv_hook if isinstance(m, nn.Conv2d) else linear_hook)
         for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]

with torch.no_grad():
    model(torch.randn(1, 3, 16, 16))  # 16x16 input: two 2x2 poolings leave 4x4 for fc1
for h in hooks:
    h.remove()
print("FLOPs: {:.2f}M".format(flops / 1e6))

# Params: total number of trainable parameters
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Params: {:.2f}M".format(params / 1e6))
```
With a 1x3x16x16 input, this prints:
```
FLOPs: 16.71M
Params: 1.15M
```
This example uses the PyTorch framework: `numel()` returns the number of elements in a tensor, `requires_grad` indicates whether a parameter receives gradients, and `register_forward_hook` lets us observe each layer's output shape, which is needed because FLOPs, unlike Params, vary with the input resolution. In practice, libraries such as thop or fvcore automate this kind of counting.
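As a sanity check, the printed parameter count can be reproduced by summing each layer's weights and biases by hand, using the layer sizes defined in the model above:

```python
# params = k*k*C_in*C_out + C_out for convolutions, in*out + out for linear layers
layer_params = {
    "conv1": 3 * 3 * 3 * 32 + 32,      # 896
    "conv2": 3 * 3 * 32 * 64 + 64,     # 18496
    "conv3": 3 * 3 * 64 * 128 + 128,   # 73856
    "fc1":   128 * 4 * 4 * 512 + 512,  # 1049088
    "fc2":   512 * 10 + 10,            # 5130
}
total = sum(layer_params.values())
print("Params: {:.2f}M".format(total / 1e6))  # Params: 1.15M
```

Note that the two fully connected layers hold over 90% of the parameters, while the convolutions dominate the FLOPs, which is a typical split for this kind of architecture.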