第三步:模型评估 评估模型的性能,获取相关性能指标。编写代码
时间: 2024-11-16 08:19:44 浏览: 32
西南交通大学 机器学习 实验3.docx
在基于图片的安全帽佩戴识别项目的第三步中,模型评估是关键环节之一。以下是如何评估模型性能并获取相关性能指标的具体步骤和示例代码:
### 步骤概述
1. **加载测试数据**:从测试集中加载图像数据及其标签。
2. **模型预测**:使用训练好的模型对测试数据进行预测。
3. **计算性能指标**:计算准确率、精确率、召回率、F1分数等指标。
4. **可视化结果**:生成混淆矩阵和ROC曲线等图表,以便直观地评估模型性能。
### 示例代码
假设你已经有一个训练好的模型 `model` 和一个测试数据集 `test_dataset`。
#### 1. 加载测试数据
```python
import torch
from torch.utils.data import DataLoader
# 假设 test_dataset 是 PyTorch 的 Dataset 对象
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
```
#### 2. 模型预测
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, preds = torch.max(outputs, 1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
```
#### 3. 计算性能指标
```python
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
# 如果是多类分类,可以使用平均值
# precision = precision_score(all_labels, all_preds, average='weighted')
# recall = recall_score(all_labels, all_preds, average='weighted')
# f1 = f1_score(all_labels, all_preds, average='weighted')
conf_matrix = confusion_matrix(all_labels, all_preds)
print('Confusion Matrix:')
print(conf_matrix)
# 如果有概率输出,可以计算 ROC AUC
probs = torch.softmax(outputs, dim=1).cpu().numpy()
roc_auc = roc_auc_score(all_labels, probs[:, 1])
print(f'ROC AUC: {roc_auc:.4f}')
```
#### 4. 可视化结果
```python
import matplotlib.pyplot as plt
import seaborn as sns
# 绘制混淆矩阵
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
# 绘制 ROC 曲线
from sklearn.metrics import roc_curve
fpr, tpr, _ = roc_curve(all_labels, probs[:, 1])
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()
```
### 总结
以上代码展示了如何加载测试数据、进行模型预测、计算性能指标并可视化结果。这些步骤有助于全面评估模型的性能,确保其在实际应用中的有效性和可靠性。
阅读全文