给出一个统计三分类混淆矩阵数据的pytorch代码
时间: 2024-03-14 08:48:04 浏览: 23
以下是一个统计三分类混淆矩阵数据的 PyTorch 代码:
```python
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
# 定义真实标签和预测标签
y_true = torch.tensor([0, 0, 1, 1, 2, 2])
y_pred = torch.tensor([0, 1, 1, 2, 2, 0])
# 计算混淆矩阵
cm = confusion_matrix(y_true.numpy(), y_pred.numpy())
print(cm)
# 计算分类准确率
acc = np.trace(cm) / np.sum(cm)
print(acc)
```
以上代码中,首先定义了真实标签 `y_true` 和预测标签 `y_pred`。然后使用 `confusion_matrix` 函数计算混淆矩阵,并输出结果。最后,根据混淆矩阵计算出分类准确率并输出结果。
相关问题
基于pytorch的一维数据的五分类混淆矩阵代码
以下是基于 PyTorch 的一维数据的五分类混淆矩阵代码示例:
```python
import torch
import numpy as np
# 假设有 100 个样本,每个样本的分类标签为 0, 1, 2, 3, 4 中的一个
y_true = torch.randint(low=0, high=5, size=(100,))
# 假设模型预测的分类为随机值,模拟混淆矩阵
y_pred = torch.randint(low=0, high=5, size=(100,))
# 计算混淆矩阵
confusion_matrix = torch.zeros(5, 5)
for i in range(len(y_true)):
confusion_matrix[y_true[i], y_pred[i]] += 1
# 输出混淆矩阵
print("Confusion Matrix:")
print(confusion_matrix)
```
该代码首先生成了 100 个随机标签的样本,以及一个随机预测的分类结果。然后,通过循环计算出混淆矩阵,并输出结果。
需要注意的是,该代码仅适用于一维数据的情况。对于多维数据,需要对每一个维度进行类似的计算。同时,该代码中假设分类标签为 0, 1, 2, 3, 4,如有需要,可以根据实际情况进行修改。
混淆矩阵pytorch代码
混淆矩阵是用于评估分类模型性能的一种常用工具。在PyTorch中,可以使用以下代码计算混淆矩阵:
```python
import torch
import numpy as np
def confusion_matrix(preds, labels, num_classes):
conf_matrix = torch.zeros(num_classes, num_classes)
preds = torch.argmax(preds, 1)
for p, t in zip(preds, labels):
conf_matrix[p, t] += 1
return conf_matrix
# 示例用法
preds = torch.tensor([0, 1, 2, 1, 0]) # 模型预测结果
labels = torch.tensor([0, 1, 2, 2, 1]) # 真实标签
num_classes = 3 # 类别数
conf_matrix = confusion_matrix(preds, labels, num_classes)
print(conf_matrix)
```
这段代码定义了一个`confusion_matrix`函数,它接受模型的预测结果`preds`、真实标签`labels`和类别数`num_classes`作为输入,并返回一个大小为`num_classes x num_classes`的混淆矩阵。
在示例中,我们使用了一个简单的示例来演示代码的用法。`preds`是模型的预测结果,`labels`是真实标签,`num_classes`是类别数。运行代码后,将输出计算得到的混淆矩阵。