python 多分类预测混淆矩阵
时间: 2023-07-21 20:07:43 浏览: 111
python sklearn包——混淆矩阵、分类报告等自动生成方式
在Python中,可以使用scikit-learn库来进行多分类预测混淆矩阵的计算。
首先,需要导入所需的库和数据集:
```python
from sklearn.metrics import confusion_matrix
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data
y = iris.target
```
接着,按照惯例将数据集分为训练集和测试集:
```python
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
```
然后,使用决策树分类器进行训练和预测:
```python
clf = DecisionTreeClassifier().fit(X_train, y_train)
y_pred = clf.predict(X_test)
```
最后,使用`confusion_matrix`函数计算混淆矩阵:
```python
cm = confusion_matrix(y_test, y_pred)
print(cm)
```
输出结果如下:
```
[[13 0 0]
[ 0 15 1]
[ 0 0 9]]
```
其中,行表示真实值,列表示预测值。对角线上的数字表示预测正确的样本数,其余数字表示预测错误的样本数。例如,第二行第三列的数字1表示有1个真实值为1的样本被错误地预测为2。
如果需要更加可视化的混淆矩阵,可以使用`heatmap`函数:
```python
import seaborn as sns
sns.heatmap(cm, annot=True)
```
这样就可以得到一个彩色的热力图,其中不同的颜色代表不同数量级的样本数量,数字表示具体的数量。
阅读全文