resnet导出onnx模型时,如何去掉模型的module前缀
时间: 2023-12-13 10:03:30 浏览: 186
清理模型插件
在 PyTorch 中使用 `nn.DataParallel` 进行多 GPU 训练时,会在模型参数前添加 `module.` 前缀,如果导出 ONNX 模型时需要去掉这个前缀,可以使用以下代码:
```python
import torch
import torch.onnx
from collections import OrderedDict
# 加载模型
model = torch.load("model.pth")
# 创建新的有序字典,去掉前缀
new_state_dict = OrderedDict()
for k, v in model.state_dict().items():
name = k[7:] if k.startswith('module.') else k
new_state_dict[name] = v
# 构建示例输入
x = torch.randn(1, 3, 224, 224)
# 导出 ONNX 模型,去掉前缀
torch.onnx.export(model, x, "model.onnx", verbose=True,
input_names=["input"], output_names=["output"],
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
keep_initializers_as_inputs=True, opset_version=12,
dynamic_axes={"input": [0], "output": [0]},
example_outputs=model(x))
# 保存去掉前缀的模型参数
torch.save(new_state_dict, "model_noprefix.pth")
```
这里,我们首先加载已经训练好的 PyTorch 模型,并创建一个新的有序字典 `new_state_dict`,遍历模型参数,去掉参数名中的 `module.` 前缀,然后存入新的有序字典中。
接着,我们构建一个示例输入 `x`,利用 `torch.onnx.export` 导出 ONNX 模型,同时指定输入输出名字、使用哪种运算符、使用哪个 ONNX 版本等参数。注意,这里需要将 `keep_initializers_as_inputs` 设置为 `True`,这样可以将模型的参数也导出为输入,方便后面使用。
最后,我们保存去掉前缀的模型参数,以便后续使用。
阅读全文