基于pytorchde混淆矩阵求出五分类的每一类准确率代码
时间: 2023-09-17 11:13:36 浏览: 73
以下是一个示例代码,用于计算基于 PyTorch 的五分类任务的混淆矩阵和每一类的准确率:
import torch
from sklearn.metrics import confusion_matrix
# 假设有五个类别,y_true 和 y_pred 分别是真实标签和预测标签
y_true = torch.tensor(y_true)
y_pred = torch.argmax(y_pred, dim=1)
conf_mat = confusion_matrix(y_true, y_pred)
# 计算每一类的准确率
acc_class = []
for i in range(5):
acc = conf_mat[i, i] / sum(conf_mat[i, :])
acc_class.append(acc)
这里使用了 torch.argmax
函数来获取预测概率最大的标签,然后使用 sklearn.metrics
库中的 confusion_matrix
函数来计算混淆矩阵,最后遍历每一行计算准确率。注意,这里假设有五个类别,如果你的任务中类别数量不同,需要相应地修改代码。
相关推荐










