怎么查看onnx模型的精度是32位还是8位的,python代码实例
时间: 2023-11-27 16:55:21 浏览: 27
可以使用ONNX Runtime的API来查看ONNX模型的精度是32位还是8位。下面是一个Python代码示例:
```python
import onnx
import onnxruntime as ort
# 加载ONNX模型
model = onnx.load("model.onnx")
# 查看模型输入和输出的数据类型
input_type = model.graph.input[0].type.tensor_type.elem_type
output_type = model.graph.output[0].type.tensor_type.elem_type
print("Input data type: ", input_type)
print("Output data type: ", output_type)
# 创建ONNX Runtime的推理会话
sess = ort.InferenceSession("model.onnx")
# 获取模型的输入和输出名称
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
# 创建一个随机输入张量
import numpy as np
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)
# 进行推理
output = sess.run([output_name], {input_name: input_data})
# 查看推理结果的数据类型
print("Output data type: ", output[0].dtype)
```
在上面的代码中,我们首先加载ONNX模型并打印出输入和输出的数据类型。然后,我们创建一个ONNX Runtime的推理会话并获取模型的输入和输出名称。接下来,我们创建一个随机输入张量并进行推理。最后,我们打印输出结果的数据类型。
如果输入和输出的数据类型是`numpy.float32`,那么模型的精度是32位。如果输入和输出的数据类型是`numpy.uint8`,那么模型的精度是8位。