图像分类ROC曲线绘制(pytorch)
时间: 2023-10-04 19:03:01 浏览: 98
首先,需要加载必要的库和数据集。这里以MNIST数据集为例。
``` python
import torch
import torchvision
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
# 加载数据集
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False)
```
接下来,需要定义模型和测试函数。这里使用一个简单的卷积神经网络作为模型。
``` python
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 6, 5)
self.pool = torch.nn.MaxPool2d(2, 2)
self.conv2 = torch.nn.Conv2d(6, 16, 5)
self.fc1 = torch.nn.Linear(16 * 4 * 4, 120)
self.fc2 = torch.nn.Linear(120, 84)
self.fc3 = torch.nn.Linear(84, 10)
def forward(self, x):
x = self.pool(torch.nn.functional.relu(self.conv1(x)))
x = self.pool(torch.nn.functional.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = torch.nn.functional.relu(self.fc1(x))
x = torch.nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return x
# 加载已训练的模型
net = Net()
net.load_state_dict(torch.load('mnist_cnn.pt'))
# 测试函数
def test(net, testloader):
net.eval()
outputs_list = []
targets_list = []
with torch.no_grad():
for i, data in enumerate(testloader):
inputs, targets = data
outputs = net(inputs)
outputs_list.append(outputs)
targets_list.append(targets)
outputs = torch.cat(outputs_list, dim=0)
targets = torch.cat(targets_list, dim=0)
return outputs, targets
```
接下来,可以使用测试函数得到模型在测试集上的输出和真实标签。然后,可以使用ROC曲线来评估模型的性能。
``` python
# 得到模型在测试集上的输出和真实标签
outputs, targets = test(net, testloader)
# 计算ROC曲线
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(10):
fpr[i], tpr[i], _ = roc_curve(targets == i, outputs[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# 绘制ROC曲线
plt.figure()
for i in range(10):
plt.plot(fpr[i], tpr[i], label='ROC curve of digit {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC curve')
plt.legend(loc="lower right")
plt.show()
```
绘制出来的ROC曲线如下所示。
![ROC curve](https://i.imgur.com/9O9TtFj.png)
阅读全文