torchvision可视化
时间: 2023-09-18 14:14:28 浏览: 71
torchvision
Torchvision 是 PyTorch 的一个扩展库,用于图像处理和计算机视觉任务。它提供了许多常用的数据集、模型、图像变换等功能,同时还提供了一些可视化工具,方便我们查看和理解数据。
下面介绍几种常用的 Torchvision 可视化方法:
1. 可视化数据集
Torchvision 提供了许多常见的数据集,如 MNIST、CIFAR-10、COCO 等。我们可以通过以下代码可视化 MNIST 数据集中的一些样本:
```python
import torchvision
import matplotlib.pyplot as plt
# 加载 MNIST 数据集
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=None)
# 可视化数据集中的前 10 张图片
fig, axs = plt.subplots(1, 10, figsize=(15, 5))
for idx in range(10):
img, target = trainset[idx]
axs[idx].imshow(img, cmap='gray')
axs[idx].set_title(str(target))
axs[idx].axis('off')
plt.show()
```
2. 可视化模型输出
在训练模型时,我们经常需要查看模型的输出结果,以判断模型是否正确地学习到了数据的特征。可以使用 `torchvision.utils.make_grid()` 函数将模型输出的多张图片拼接成一张大图,方便我们直观地观察模型的学习效果。
以下是一个将模型输出的前 16 张图片可视化的示例代码:
```python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# 加载 CIFAR-10 数据集并进行预处理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 加载训练好的模型
net = Net()
net.load_state_dict(torch.load(PATH))
# 获取模型输出的前 16 张图片
dataiter = iter(trainloader)
images, labels = dataiter.next()
outputs = net(images)
# 将模型输出的多张图片拼接成一张大图
img_grid = torchvision.utils.make_grid(images[:16], nrow=4)
# 可视化大图
plt.imshow(img_grid.permute(1, 2, 0))
plt.show()
```
3. 可视化特征图
在卷积神经网络中,每层卷积操作的输出都是一组特征图。我们可以使用 `torchvision.utils.make_grid()` 函数将特征图拼接成一张大图,从而观察不同卷积层的特征图。
以下是一个将模型不同卷积层的特征图可视化的示例代码:
```python
import torch
import torchvision
import matplotlib.pyplot as plt
# 加载 CIFAR-10 数据集并进行预处理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 加载训练好的模型
net = Net()
net.load_state_dict(torch.load(PATH))
# 提取指定层的特征图
conv_layers = []
for name, layer in net.named_modules():
if isinstance(layer, nn.Conv2d):
conv_layers.append(name)
outputs = []
for name in conv_layers:
x = images
for n, l in net.named_modules():
if n == name:
x = l(x)
outputs.append(x)
# 将特征图拼接成一张大图
img_grid = torchvision.utils.make_grid(outputs, nrow=len(conv_layers))
# 可视化大图
plt.imshow(img_grid.permute(1, 2, 0))
plt.show()
```
以上就是几种常用的 Torchvision 可视化方法。需要注意的是,在可视化特征图时,我们需要提前指定要可视化的卷积层,并通过模型中的 `named_modules()` 函数获取该层的输出。
阅读全文