img_masked1 = img * torch.from_numpy(mask1).float()报错:The size of tensor a (128) must match the size of tensor b (3) at non-singleton dimension 2
时间: 2023-06-10 21:05:41 浏览: 138
这个错误的原因是img的第三个维度的大小为3,而mask1的第三个维度的大小为1,两个 tensor 的形状不匹配。你需要将 mask1 的形状改为与 img 相同的形状,比如可以将 mask1 在第三个维度上扩展为3,如下所示:
mask1 = np.expand_dims(mask1, axis=2)
img_masked1 = img * torch.from_numpy(mask1).float()
相关问题
img_masked1 = img * torch.from_numpy(mask1).float() 什么意思
这行代码的意思是将一个numpy数组 mask1 转换为 tensor 类型后与 tensor 类型的图像 img 进行逐元素相乘。得到的结果是一个与原图大小相同的 tensor,其中 mask1 的每个元素都与 img 对应位置的元素相乘得到。这种操作可以实现对图像中特定区域进行遮罩或滤波等处理。
把下面的代码用c++实现 import numpy as np import torch import cv2 result = np.fromfile('permute_1_0.raw', dtype=np.float32) img = result.reshape((2,480,480)) torch_img = torch.from_numpy(img) index = torch.max(torch_img, dim=0)[1].numpy().astype(np.uint8) index[index > 0] = 255 # 白色 cv2.imshow('out', index) cv2.waitKey(0)
以下是用C++实现的代码:
```
#include <opencv2/opencv.hpp>
#include <torch/torch.h>
int main() {
// 读取数据
std::ifstream input("permute_1_0.raw", std::ios::binary);
std::vector<float> data(2 * 480 * 480);
input.read(reinterpret_cast<char*>(data.data()), data.size() * sizeof(float));
input.close();
// 转换数据类型
torch::Tensor tensor = torch::from_blob(data.data(), {2, 480, 480}, torch::kFloat32).clone();
// 获取最大值索引
torch::Tensor max = std::get<0>(torch::max(tensor, 0));
torch::Tensor index = max.argmax(0).to(torch::kUInt8);
// 去除背景
index.masked_fill_(index == 0, 255);
// 转换为OpenCV格式
cv::Mat out(index.sizes()[0], index.sizes()[1], CV_8UC1, index.data_ptr());
// 显示图像
cv::imshow("out", out);
cv::waitKey(0);
return 0;
}
```
注意需要安装OpenCV和PyTorch C++ API,并且需要将`permute_1_0.raw`文件放在当前目录下。
阅读全文