写一段代码查询网络模型的参数量
时间: 2023-12-01 13:16:03 浏览: 96
pytorch 求网络模型参数实例
当然可以,以下是查询网络模型参数量的代码,假设模型名称为model:
import torch
from torchsummary import summary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 判断是否使用 GPU
model.to(device)
summary(model, (input_size_1, input_size_2, ...)) # 将模型参数打印出来,输入大小可以自行设置
print(f'总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad):,d}') # 输出总参数量,逗号作为千位分隔符,更易读
阅读全文