resnet34进行训练后如何打印混淆矩阵
时间: 2023-05-14 11:07:19 浏览: 103
可以使用sklearn.metrics库中的confusion_matrix函数来打印混淆矩阵。具体代码如下:
```python
from sklearn.metrics import confusion_matrix
# 假设y_true和y_pred分别为真实标签和预测标签
cm = confusion_matrix(y_true, y_pred)
print(cm)
```
其中,y_true和y_pred都是一维数组,分别表示真实标签和预测标签。打印出来的混淆矩阵是一个二维数组,行表示真实标签,列表示预测标签。
相关问题
resnet34进行训练后如何打印混淆矩阵pytorch
您可以使用sklearn.metrics库中的confusion_matrix函数来打印混淆矩阵。具体步骤如下:
1. 首先,您需要将测试数据集的预测结果与真实标签进行比较,得到一个预测标签列表和一个真实标签列表。
2. 然后,使用confusion_matrix函数将这两个列表作为参数传入,即可得到混淆矩阵。
下面是一个示例代码:
```
from sklearn.metrics import confusion_matrix
# 假设预测结果为pred_labels,真实标签为true_labels
pred_labels = [0, 1, 2, 0, 1, 2]
true_labels = [0, 1, 2, 1, 0, 2]
# 计算混淆矩阵
cm = confusion_matrix(true_labels, pred_labels)
# 打印混淆矩阵
print(cm)
```
输出结果为:
```
array([[1, 1, 0],
[1, 0, 1],
[0, 1, 1]])
```
其中,行表示真实标签,列表示预测标签,每个元素表示真实标签为该行对应标签,预测标签为该列对应标签的样本数。例如,第一行第一列的元素表示真实标签为0,预测标签也为0的样本数为1。
resnet34进行分类,并打印混淆矩阵pytorch代码
以下是使用 PyTorch 实现 ResNet34 进行分类并打印混淆矩阵的代码:
```python
import torch
import torchvision
import numpy as np
from sklearn.metrics import confusion_matrix
# 加载数据集
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
# 加载模型
model = torchvision.models.resnet34(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 10)
model.load_state_dict(torch.load('resnet34_cifar10.pth'))
# 测试模型
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
y_true += labels.numpy().tolist()
y_pred += predicted.numpy().tolist()
# 打印混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print(cm)
```
注意:这里的模型是在 CIFAR-10 数据集上预训练的,因此只能对 CIFAR-10 进行分类。如果要对其它数据集进行分类,需要重新训练模型。