上述代码overall_acc 输出不出来
时间: 2023-12-19 18:07:59 浏览: 83
overall_overallspillover_
很抱歉,我犯了一个错误。在计算整体准确率时,应该使用 `np.trace()` 函数来计算混淆矩阵对角线上的元素之和。以下是修正后的代码:
```
with torch.no_grad():
class_accuracy_list = []
true_labels = []
predicted_labels = []
for data in testloader:
images, labels = data
images = images.view(len(images), 1, 121).float()
inputs, labels = images.to(device), labels.to(device)
outputs = cnn(inputs)
_, predicted = torch.max(outputs, 1)
true_labels.append(labels.cpu().numpy())
predicted_labels.append(predicted.cpu().numpy())
true_labels = np.concatenate(true_labels)
predicted_labels = np.concatenate(predicted_labels)
cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1, 2, 3, 4])
print(cm)
for i in range(cm.shape[0]):
acc = cm[i, i] / np.sum(cm[i])
class_accuracy_list.append(acc)
print("Class", i, "Accuracy:", acc)
overall_acc = np.trace(cm) / np.sum(cm)
print("Overall Accuracy:", overall_acc)
```
现在,整体准确率应该可以正确输出了。
阅读全文