pytorch如何展示图片
时间: 2023-07-09 09:31:34 浏览: 91
展示图片可以使用 PyTorch 中的 `torchvision` 模块。具体步骤如下:
1. 导入需要的模块:
```python
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms
```
2. 定义数据变换:
```python
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
```
3. 加载数据集:
```python
testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
```
4. 取出一批数据:
```python
dataiter = iter(testloader)
images, labels = dataiter.next()
```
5. 将数据转换为 numpy 数组,并将像素值还原:
```python
images = images.numpy()
images = ((images * 0.5) + 0.5) * 255
images = images.astype(np.uint8)
```
6. 展示图片:
```python
fig, axes = plt.subplots(figsize=(10, 4), ncols=4)
for i in range(4):
ax = axes[i]
ax.imshow(images[i].squeeze(), cmap='gray')
ax.axis('off')
ax.set_title(str(labels[i].item()))
plt.show()
```
这样就能够展示图片了。以上示例展示了 MNIST 数据集中的四张图片。你也可以根据需要修改数据集和数据变换的设置。
阅读全文