请在上述代码下面加入每类的准确率
时间: 2024-02-22 22:55:54 浏览: 73
可以使用混淆矩阵来计算每类的准确率,具体方法如下:
1. 对于每个类别i,将混淆矩阵的第i行作为该类别的预测结果,第i列作为该类别的真实结果。
2. 计算该类别的准确率,即正确预测的样本数除以总样本数。
以下是修改后的代码:
```python
with torch.no_grad():
true_labels = []
predicted_labels = []
for data in testloader:
images, labels = data
images = images.view(len(images), 1, 121).float()
inputs, labels = images.to(device), labels.to(device)
outputs = cnn(inputs)
_, predicted = torch.max(outputs, 1)
true_labels.append(labels.cpu().numpy())
predicted_labels.append(predicted.cpu().numpy())
true_labels = np.concatenate(true_labels)
predicted_labels = np.concatenate(predicted_labels)
cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1, 2, 3, 4])
print(cm)
for i in range(cm.shape[0]):
acc = cm[i, i] / np.sum(cm[i])
print("Class", i, "Accuracy:", acc)
```
运行结果应该类似于:
```
[[ 7 1 0 0 0]
[ 0 27 1 0 0]
[ 0 0 30 0 0]
[ 0 0 0 27 1]
[ 0 1 0 1 100]]
Class 0 Accuracy: 0.875
Class 1 Accuracy: 0.9642857142857143
Class 2 Accuracy: 1.0
Class 3 Accuracy: 0.9642857142857143
Class 4 Accuracy: 0.9803921568627451
```
其中,每个类别的准确率均大于等于0,小于等于1,可以作为模型性能的重要指标之一。
阅读全文