Linux 环境下基于 C++ 的 TNN 分割代码,并将分割结果保存成图片
时间: 2024-03-28 20:41:54 浏览: 12
以下是基于C++的TNN分割代码,并将分割结果保存为图片:
```cpp
#include <cstddef>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "opencv2/opencv.hpp"

// Public TNN API used by main(): TNN, Instance, Mat, MatConvertParam, Status.
#include "tnn/core/blob.h"
#include "tnn/core/common.h"
#include "tnn/core/instance.h"
#include "tnn/core/mat.h"
#include "tnn/core/status.h"
#include "tnn/core/tnn.h"

// NOTE(review): the headers below come from the original snippet. The `.hpp`
// names and the device/network/utils internals do not appear to be part of
// TNN's installed public include set — confirm each path exists in your TNN
// build (the TensorRT ones also require a CUDA/TensorRT-enabled build).
#include "tnn/core/common.hpp"
#include "tnn/core/context.hpp"
#include "tnn/core/profile.h"
#include "tnn/device/cpu/cpu_device.h"
#include "tnn/device/cpu/cpu_context.h"
#include "tnn/utils/blob_converter.h"
#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/naive_compute.h"
#include "tnn/utils/omp_utils.h"
#include "tnn/utils/cpu_utils.h"
#include "tnn/network/tensorrt/tensorrt_network.h"
#include "tnn/network/tensorrt/tensorrt_common.h"

using namespace TNN_NS;
int main(int argc, char** argv) {
// 定义输入图像大小
int input_width = 224;
int input_height = 224;
// 创建网络实例
auto proto_content = fdLoadFile("model.tnnproto");
auto model_content = fdLoadFile("model.tnnmodel");
auto network = std::make_shared<TensorRTNetwork>();
auto status = network->LoadFromBuffer(proto_content, model_content);
CHECK_TNN_OK(status);
// 获取输入输出节点
auto input_name = "input";
auto output_name = "output";
auto input_dims = network->GetInputShape(input_name);
auto output_dims = network->GetOutputShape(output_name);
// 创建输入输出blob
std::shared_ptr<RawBuffer> input_buffer(new RawBuffer(input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float)));
std::shared_ptr<Mat> input_mat = std::make_shared<Mat>(input_dims, input_buffer);
std::shared_ptr<RawBuffer> output_buffer(new RawBuffer(output_dims[1] * output_dims[2] * output_dims[3] * sizeof(float)));
std::shared_ptr<Mat> output_mat = std::make_shared<Mat>(output_dims, output_buffer);
// 创建CPU设备和上下文
std::shared_ptr<CpuDevice> cpu_device(new CpuDevice(DEVICE_NAIVE));
CpuContext cpu_context;
cpu_context.device = cpu_device.get();
// 初始化输入数据
MatUtils::InitRandom(*input_mat);
auto input_data = input_mat->GetData();
// 执行推理
status = network->Forward({ input_mat }, { output_mat }, &cpu_context);
CHECK_TNN_OK(status);
// 输出结果
auto output_data = output_mat->GetData();
cv::Mat seg_map(output_dims[2], output_dims[3], CV_32FC1, output_data);
cv::Mat seg_map_8u;
seg_map.convertTo(seg_map_8u, CV_8UC1, 255.0);
// 保存结果为图片
cv::imwrite("seg_result.png", seg_map_8u);
return 0;
}
```
注意:需要将 `model.tnnproto` 和 `model.tnnmodel` 替换为你自己的模型文件路径,并根据你的模型修改输入和输出节点的名称。此外,还需要安装 OpenCV 库来读写图片,并编译、链接 TNN 库。