onnxruntime C++ 多batchsize推理
时间: 2023-07-29 19:09:23 浏览: 257
使用 onnxruntime C++ 进行多 batchsize 推理需要使用 onnxruntime C++ API 的 Session 类。在创建 Session 时需要设置 batch_size,然后在调用 Session.Run() 方法时传入多个输入数据。
示例代码如下:
```cpp
#include <iostream>
#include <vector>
#include "onnxruntime_cxx_api.h"
int main() {
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
session_options.SetLogSeverityLevel(1);
// 创建 Session,设置 batch_size
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::Session session(env, "model.onnx", session_options);
Ort::AllocatorWithDefaultOptions allocator;
size_t input_tensor_size = 3 * 224 * 224;
size_t batch_size = 3;
std::vector<float> input_data_1(input_tensor_size * batch_size);
std::vector<float> input_data_2(input_tensor_size * batch_size);
// 多个输入数据
for (size_t i = 0; i < input_tensor_size * batch_size; ++i) {
input_data_1[i] = rand() / (float)(RAND_MAX);
input_data_2[i] = rand() / (float)(RAND_MAX);
}
// 多个输入数据
std::vector<int64_t> input_shape = {batch_size, 3, 224, 224};
std::vector<Ort::Value> input_tensors;
input_tensors.push_back(
Ort::Value::CreateTensor<float>(allocator, input_data_1.data(), input_tensor_size * batch_size, input_shape.data(), input_shape.size()));
input_tensors.push_back(
Ort::Value::CreateTensor<float>(allocator, input_data_2.data(), input_tensor_size * batch_size, input_shape.data(), input_shape.size()));
// 输出数据
std::vector<const char*> output_names = {"output"};
std::vector<Ort::Value> output_tensors;
output_tensors.push_back(Ort::Value::CreateTensor<float>(allocator, output_shape.data(), output_shape.size()));
// 推理
session.Run(Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_tensors.data(), output_tensors.size());
return 0;
}
```
在上面的示例代码中,先创建了一个 SessionOptions 对象,并设置了一些属性。然后创建 Session 时设置 batch_size,并指定模型文件路径。在推理之前,先创建多个输入张量,然后调用 Session.Run() 方法分别进行推理。
阅读全文