高光谱图像分类结果可视化代码pytorch
时间: 2023-10-21 22:44:47 浏览: 177
以下是一个使用PyTorch的高光谱图像分类结果可视化的代码示例:
```python
import torch
import torchvision
import matplotlib.pyplot as plt
# 加载训练好的模型
model = torch.load('model.pt')
# 加载测试数据集
test_dataset = torchvision.datasets.DatasetFolder('test_data', transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)
# 获取测试数据和预测结果
data, _ = next(iter(test_loader))
prediction = model(data)
# 可视化分类结果
class_names = test_dataset.classes
predicted_class = torch.argmax(prediction).item()
predicted_label = class_names[predicted_class]
plt.imshow(data.squeeze().permute(1, 2, 0))
plt.title(f'Predicted Label: {predicted_label}')
plt.axis('off')
plt.show()
```
请确保将代码中的`model.pt`替换为你的训练好的模型的路径,`test_data`替换为你的测试数据集的路径。此代码将加载训练好的模型并使用测试数据集中的一张图像进行预测,并可视化预测结果。