from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score,confusion_matrix import numpy as np # 导入iris数据集 iris = load_iris() # 提取数据集中的特征数据 X = iris.data # 提取label y = iris.target # 划分训练集和测试集 X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=42,test_size=0.5,stratify=y) # 导入决策树,设置参数,最大深度为3,使用gini系数 tree = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42) # 拟合训练集 tree.fit(X_train,y_train) # 预测测试集 y_predict = tree.predict(X_test) # 计算准确率 accuracy = accuracy_score(y_test,y_predict) # 混淆矩阵 cm = confusion_matrix(y_test,y_predict) TP = np.diag(cm) FN = cm.sum(axis=1) - np.diag(cm) FP = cm.sum(axis=0) - np.diag(cm) TN = cm.sum() - (TP+FN+FP) recall = TP / (TP+FN) # Overall Accuracy OA = (TP+TN) / (TP+FN+FP+TN) print(OA) # Average Accuracy AA = recall.sum() / len(recall) print(AA)逐句详细分析一下
时间: 2024-02-27 20:51:59 浏览: 64
这段代码使用决策树算法对鸢尾花数据集进行训练和分类测试,并计算了准确率、混淆矩阵、召回率和平均准确率等指标。
具体的逐句分析如下:
1. 导入需要的库
```
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score,confusion_matrix
import numpy as np
```
首先导入需要的库,包括数据集iris、决策树分类器、数据集划分、准确率和混淆矩阵的计算库,以及Numpy库。
2. 加载数据集
```
iris = load_iris()
X = iris.data
y = iris.target
```
使用load_iris()函数加载iris数据集,并将特征数据和标签分别赋值给X和y。
3. 划分训练集和测试集
```
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=42,test_size=0.5,stratify=y)
```
使用train_test_split()函数将数据集划分为训练集和测试集。其中,random_state=42表示随机种子为42,test_size=0.5表示测试集占总样本的50%,stratify=y表示按照标签y进行分层抽样,以保证训练集和测试集中的类别比例相同。
4. 定义决策树分类器
```
tree = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
```
使用DecisionTreeClassifier()函数定义决策树分类器,其中criterion='gini'表示使用Gini系数作为划分标准,max_depth=3表示决策树的最大深度为3,random_state=42表示随机种子为42。
5. 训练模型
```
tree.fit(X_train,y_train)
```
使用fit()函数对训练集进行拟合,以训练决策树模型。
6. 对测试集进行预测
```
y_predict = tree.predict(X_test)
```
使用predict()函数对测试集进行预测,得到预测结果y_predict。
7. 计算准确率和混淆矩阵
```
accuracy = accuracy_score(y_test,y_predict)
cm = confusion_matrix(y_test,y_predict)
```
使用accuracy_score()函数计算准确率,使用confusion_matrix()函数计算混淆矩阵。
8. 计算召回率和平均准确率
```
TP = np.diag(cm)
FN = cm.sum(axis=1) - np.diag(cm)
FP = cm.sum(axis=0) - np.diag(cm)
TN = cm.sum() - (TP+FN+FP)
recall = TP / (TP+FN)
OA = (TP+TN) / (TP+FN+FP+TN)
AA = recall.sum() / len(recall)
print(OA)
print(AA)
```
根据混淆矩阵计算True Positive(TP)、False Negative(FN)、False Positive(FP)和True Negative(TN),以计算召回率(recall)和平均准确率(AA)。其中,召回率表示正确预测的样本数占总样本数的比例,平均准确率表示召回率的平均值。最后,打印输出OA和AA的值。
阅读全文