把下面的代码用c++实现,不使用torch import numpy as np import torch import cv2 result = np.fromfile('permute_1_0.raw', dtype=np.float32) img = result.reshape((2,480,480)) torch_img = torch.from_numpy(img) index = torch.max(torch_img, dim=0)[1].numpy().astype(np.uint8) index[index > 0] = 255 # 白色 cv2.imshow('out', index) cv2.waitKey(0)
时间: 2024-03-19 07:45:54 浏览: 150
以下是使用 C++ 实现的代码:
```cpp
#include <iostream>
#include <fstream>
#include <opencv2/opencv.hpp>
#include <torch/torch.h>
using namespace std;
using namespace cv;
int main() {
// 读取二进制文件
ifstream fin("permute_1_0.raw", ios::in | ios::binary);
if (!fin) {
cerr << "Failed to open file." << endl;
return -1;
}
// 读取数据到数组
const int channels = 2;
const int height = 480;
const int width = 480;
float data[channels * height * width];
fin.read(reinterpret_cast<char*>(data), channels * height * width * sizeof(float));
fin.close();
// 将数组转为 OpenCV Mat 对象
Mat img(channels, cv::Size(width, height), CV_32F, data);
// 将 Mat 转为 Tensor 对象
auto tensor_img = torch::from_blob(img.data, {channels, height, width}, torch::kFloat).clone();
// 求取最大值对应的索引
auto indices = torch::argmax(tensor_img, 0, false).to(torch::kByte);
// 将 Tensor 转为 OpenCV Mat 对象
Mat index(height, width, CV_8UC1, indices.data_ptr());
// 处理结果并显示
threshold(index, index, 0, 255, THRESH_BINARY);
imshow("out", index);
waitKey(0);
return 0;
}
```
需要注意的是,本例中使用了 OpenCV 和 PyTorch C++ API。在编译时需要添加相应的头文件和库文件,并确保环境配置正确。
阅读全文