如何使用torch.onnx.export
时间: 2023-05-22 20:05:45 浏览: 646
torch.onnx.export 是用来将 PyTorch 模型导出成 ONNX 格式的函数,其具体用法为:
torch.onnx.export(model, # 导出模型
args, # 模型的输入参数
f, # 导出 ONNX 模型的文件名
export_params=True, # 是否导出模型的参数
opset_version=10, # 导出的 ONNX 格式版本
do_constant_folding=True, # 是否进行常量折叠优化
input_names=["input"], # 模型输入的名称
output_names=["output"], # 模型输出的名称
dynamic_axes={"input":{0:"batch_size"}, # 动态图下输入的维度
"output":{0:"batch_size"}}) # 动态图下输出的维度
需要注意的是,torch.onnx.export 可以导出的模型类型包括:nn.Module、nn.Sequential 和 nn.ModuleList。此外,该函数中的参数 f 应该是一个 str 类型的文件名,而不是一个文件句柄。
阅读全文