svm多分类混淆矩阵绘图
时间: 2024-12-31 16:17:15 浏览: 4
### 绘制SVM多分类的混淆矩阵
为了绘制支持向量机(SVM)模型在多分类任务中的混淆矩阵,可以按照如下方法操作。首先加载必要的库并准备数据集:
```python
from sklearn import svm
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
```
接着定义函数来训练SVM模型并对测试集做出预测:
```python
def train_and_predict(X_train, y_train, X_test):
clf = svm.SVC(kernel='linear', C=1, decision_function_shape='ovo') # 使用一对一策略处理多类问题
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)
return predictions
```
之后创建用于展示混淆矩阵的辅助函数:
```python
def plot_confusion_mat(y_true, y_pred, class_names):
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
fig, ax = plt.subplots(figsize=(8, 6))
disp.plot(ax=ax, cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.show()
```
最后组合上述组件完成整个流程:
```python
# 加载鸢尾花数据集作为例子
data = load_iris()
X = data.data
y = data.target
class_names = data.target_names
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 训练模型并获取预测结果
predictions = train_and_predict(X_train, y_train, X_test)
# 展示混淆矩阵
plot_confusion_mat(y_test, predictions, class_names)
```
此过程展示了如何利用`sklearn`库中的工具构建一个多类别分类器,并通过可视化手段评估其性能表现。
阅读全文