tensorrt demo c++
时间: 2023-07-01 13:26:08 浏览: 105
以下是一个简单的TensorRT C++ demo:它用ONNX解析器加载MNIST模型,构建TensorRT推理引擎,然后对一张MNIST手写数字图像进行分类:
```c++
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>

#include <cuda_runtime_api.h>

#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "NvOnnxParser.h"
using namespace nvinfer1;
using namespace plugin;
int main(int argc, char** argv) {
// Load the ONNX model
std::string onnx_model_file = "mnist.onnx";
IBuilder* builder = createInferBuilder(gLogger);
INetworkDefinition* network = builder->createNetwork();
auto parser = nvonnxparser::createParser(*network, gLogger);
parser->parseFromFile(onnx_model_file.c_str(), -1);
builder->setMaxBatchSize(1);
builder->setMaxWorkspaceSize(1 << 30);
// Set the input and output dimensions
Dims input_dims = network->getInput(0)->getDimensions();
input_dims.d[0] = 1; // Set batch size to 1
network->getInput(0)->setDimensions(input_dims);
network->getOutput(0)->setDimensions(Dims4(1, 10, 1, 1));
// Build the engine
ICudaEngine* engine = builder->buildCudaEngine(*network);
// Create execution context
IExecutionContext* context = engine->createExecutionContext();
// Create input and output buffers
void* input_buffer;
void* output_buffer;
cudaMalloc(&input_buffer, input_dims.numel() * sizeof(float));
cudaMalloc(&output_buffer, 10 * sizeof(float));
// Load the input data
float input_data[28 * 28];
std::ifstream input_file("test_input.txt");
std::string line;
int i = 0;
while (getline(input_file, line)) {
std::stringstream ss(line);
ss >> input_data[i++];
}
// Copy the input data to GPU
cudaMemcpy(input_buffer, input_data, input_dims.numel() * sizeof(float), cudaMemcpyHostToDevice);
// Run inference
context->execute(1, &input_buffer, &output_buffer);
// Copy the output back to CPU
float output_data[10];
cudaMemcpy(output_data, output_buffer, 10 * sizeof(float), cudaMemcpyDeviceToHost);
// Print the output
std::cout << "Output: ";
for (int i = 0; i < 10; i++) {
std::cout << output_data[i] << " ";
}
std::cout << std::endl;
// Clean up
cudaFree(input_buffer);
cudaFree(output_buffer);
context->destroy();
engine->destroy();
network->destroy();
builder->destroy();
parser->destroy();
return 0;
}
```
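这个demo以一张28x28像素的MNIST手写数字图像作为输入(在 test_input.txt 中以784个浮点数给出),输出一个包含10个元素的向量,每个元素对应一个数字(0–9)的得分。如果模型末尾没有Softmax层,这些得分是未归一化的logits,需要再经过Softmax才是概率。该demo在运行时直接用ONNX解析器读取 mnist.onnx 并构建TensorRT引擎,因此无需事先把ONNX模型转换为TensorRT格式,只需准备好 mnist.onnx 和 test_input.txt 两个文件。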