基于 TNN 的 C++ 图像分类代码
时间: 2024-02-27 13:53:27 浏览: 114
以下是一个基于TNN的C++图像分类代码示例:
```cpp
#include <algorithm>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tnn/core/macro.h"
#include "tnn/core/blob.h"
#include "tnn/core/common.h"
#include "tnn/core/instance.h"
#include "tnn/core/mat.h"
#include "tnn/core/status.h"
#include "tnn/core/tnn.h"
#include "tnn/utils/mat_utils.h"
#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/blob_converter.h"
// NOTE(review): the interpreter headers below are not used by main(); they may be
// internal to the TNN source tree -- drop them if they do not resolve against your install.
#include "tnn/interpreter/default_model_interpreter.h"
#include "tnn/interpreter/tnn/tnn_interpreter.h"
#include "tnn/interpreter/tnn/tnn_device.h"
#include "tnn/interpreter/tnn/tnn_utils_internal.h"
using namespace TNN_NS;
int main(int argc, char** argv) {
if (argc < 3) {
std::cout << "Usage: " << argv[0] << " proto model image" << std::endl;
return -1;
}
std::string proto_content, model_content;
if (ReadProtoFile(argv[1], &proto_content) != TNN_OK) {
std::cout << "Read proto file failed." << std::endl;
return -1;
}
if (ReadModelFile(argv[2], model_content) != TNN_OK) {
std::cout << "Read model file failed." << std::endl;
return -1;
}
std::shared_ptr<TNN> tnn = std::make_shared<TNN>();
TNNStatus status = tnn->Init(proto_content, model_content, "", TNNComputeUnitsCPU);
if (status != TNN_OK) {
std::cout << "Init TNN failed, error code: " << (int)status << std::endl;
return -1;
}
auto input_dims = tnn->GetInputShape(0);
if (input_dims.size() != 4) {
std::cout << "Invalid input dims." << std::endl;
return -1;
}
auto input_mat = std::make_shared<Mat>(input_dims, MatType::NCHW_FLOAT);
int input_size = DimsVectorUtils::Count(input_dims);
auto converter = std::make_shared<NCBlobConverter>();
size_t input_bytes_size = input_size * sizeof(float);
RawBuffer input_buffer(input_bytes_size);
std::ifstream in_file(argv[3], std::ios::binary);
if (!in_file.is_open()) {
std::cout << "Open image file failed." << std::endl;
return -1;
}
in_file.read(reinterpret_cast<char*>(input_buffer.force_to<void*>()), input_bytes_size);
in_file.close();
converter->ConvertFromHostToDevice(input_buffer, input_mat, nullptr);
std::shared_ptr<TNNInterpreter> interpreter = std::make_shared<TNNInterpreter>();
status = interpreter->Init(tnn->GetModelConfig());
if (status != TNN_OK) {
std::cout << "Init interpreter failed, error code: " << (int)status << std::endl;
return -1;
}
std::shared_ptr<TNNSession> session = interpreter->CreateSession(tnn->GetModelConfig());
if (!session) {
std::cout << "Create session failed." << std::endl;
return -1;
}
status = session->SetInputMat(input_mat);
if (status != TNN_OK) {
std::cout << "Set input mat failed, error code: " << (int)status << std::endl;
return -1;
}
status = session->Forward();
if (status != TNN_OK) {
std::cout << "Forward failed, error code: " << (int)status << std::endl;
return -1;
}
auto output_dims = tnn->GetOutputShape(0);
if (output_dims.size() != 2) {
std::cout << "Invalid output dims." << std::endl;
return -1;
}
auto output_size = DimsVectorUtils::Count(output_dims);
auto output_mat = std::make_shared<Mat>(output_dims, MatType::NCHW_FLOAT);
auto output_buffer = std::make_shared<RawBuffer>(output_size * sizeof(float));
converter->ConvertFromDeviceToHost(output_mat, output_buffer, nullptr);
status = session->GetOutputMat(output_mat);
if (status != TNN_OK) {
std::cout << "Get output mat failed, error code: " << (int)status << std::endl;
return -1;
}
std::vector<float> output_data(output_size);
memcpy(output_data.data(), output_buffer->force_to<void*>(), output_size * sizeof(float));
int max_index = std::max_element(output_data.begin(), output_data.end()) - output_data.begin();
std::cout << "Predicted class index: " << max_index << std::endl;
return 0;
}
```
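这段代码会读取 TNN 模型的 proto 文件和 model 文件，以及一个输入数据文件，并使用 TNN 进行图像分类。注意：输入文件是按原始字节直接读入 float 缓冲区的，因此它必须是已经预处理好（缩放、归一化、按 NCHW 排列）的 float32 张量，而不是 JPEG/PNG 等编码后的图像。这段代码仅供参考，实际使用时需要根据所用 TNN 版本的 API 和具体模型进行修改。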
阅读全文