pytorch读取图像
时间: 2023-11-18 09:04:02 浏览: 93
PyTorch提供了多种读取图像的方法,包括使用PIL、OpenCV和torchvision.io等库。下面是三种常用的读取图像的方法:
1. 使用PIL库:
```
from PIL import Image
import torch
path = '../dataset/train/hazy/8902.jpg'
img = Image.open(path)
img = img.convert('RGB')
img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float()
print(type(img))
print(img.shape)
```
输出结果为:
```
<class 'torch.Tensor'>
torch.Size([3, 344, 550])
```
2. 使用OpenCV库:
```
import cv2 as cv
import torch
path = '../dataset/train/hazy/8902.jpg'
img = cv.imread(path)
img = cv.cvtColor(img, code=cv.COLOR_BGR2RGB)
img = torch.from_numpy(img).permute(2, 0, 1).float()
print(type(img))
print(img.shape)
```
输出结果为:
```
<class 'torch.Tensor'>
torch.Size([3, 344, 550])
```
3. 使用torchvision.io库:
```
from torchvision.io import image
import torch
path = '../dataset/train/hazy/8902.jpg'
img = image.read_image(path)
img = img.float()
print(type(img))
print(img.shape)
```
输出结果为:
```
<class 'torch.Tensor'>
torch.Size([3, 344, 550])
```
注意事项:
1. 读取的图像需要转换为RGB格式。
2. 读取的图像需要转换为torch.Tensor类型。
3. 读取的图像需要进行通道维度的转换。
阅读全文