pytorch混淆矩阵代码
时间: 2023-07-29 20:11:12 浏览: 99
混淆矩阵
下面是一个使用PyTorch创建混淆矩阵的示例代码:
```python
import torch
from sklearn.metrics import confusion_matrix
# 定义真实标签和预测标签
y_true = torch.tensor([1, 0, 1, 0, 1, 1, 0, 1])
y_pred = torch.tensor([1, 1, 0, 0, 1, 0, 1, 1])
# 计算混淆矩阵
cm = confusion_matrix(y_true.numpy(), y_pred.numpy())
print(cm)
```
这段代码首先导入了必要的库,然后定义了真实标签`y_true`和预测标签`y_pred`。接着使用`confusion_matrix`函数计算混淆矩阵,并将结果打印出来。
请注意,这里使用了`sklearn.metrics`中的`confusion_matrix`函数,因为PyTorch本身并没有提供内置的混淆矩阵计算函数。因此,在运行这段代码之前,请确保你已经安装了scikit-learn库。
阅读全文