from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score,confusion_matrix import numpy as np # 导入iris数据集 iris = load_iris() # 提取数据集中的特征数据 X = iris.data # 提取label y = iris.target # 划分训练集和测试集 X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=42,test_size=0.5,stratify=y) # 导入决策树,设置参数,最大深度为3,使用gini系数 tree = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42) # 拟合训练集 tree.fit(X_train,y_train) # 预测测试集 y_predict = tree.predict(X_test) # 计算准确率 accuracy = accuracy_score(y_test,y_predict) # 混淆矩阵 cm = confusion_matrix(y_test,y_predict) TP = np.diag(cm) FN = cm.sum(axis=1) - np.diag(cm) FP = cm.sum(axis=0) - np.diag(cm) TN = cm.sum() - (TP+FN+FP) recall = TP / (TP+FN) # Overall Accuracy OA = (TP+TN) / (TP+FN+FP+TN) print(OA) # Average Accuracy AA = recall.sum() / len(recall) print(AA)将每个部分分析一下
时间: 2024-02-27 14:51:47 浏览: 74
Decision-tree-code-using-sklearn
好的,下面是代码的分析:
1. 导入需要的库:
```
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score,confusion_matrix
import numpy as np
```
这里导入了scikit-learn的iris数据集、决策树分类器、数据集划分函数、准确率和混淆矩阵评估指标以及numpy库。
2. 提取数据集中的特征数据和标签数据:
```
iris = load_iris()
X = iris.data
y = iris.target
```
这里使用load_iris函数导入iris数据集,然后分别提取数据集中的特征数据和标签数据。
3. 划分训练集和测试集:
```
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=42,test_size=0.5,stratify=y)
```
这里使用train_test_split函数将数据集划分为训练集和测试集,并且设置了随机种子random_state为42,测试集占比为50%,并且按照标签y进行分层划分。
4. 定义决策树分类器,设置参数:
```
tree = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
```
这里定义了一个决策树分类器,并且设置了criterion为'gini'表示使用gini系数作为分裂标准,最大深度为3,随机种子为42。
5. 拟合训练集:
```
tree.fit(X_train,y_train)
```
这里使用fit函数拟合训练集,训练出决策树模型。
6. 预测测试集:
```
y_predict = tree.predict(X_test)
```
这里使用predict函数预测测试集的标签。
7. 计算准确率和混淆矩阵:
```
accuracy = accuracy_score(y_test,y_predict)
cm = confusion_matrix(y_test,y_predict)
```
这里使用accuracy_score函数计算准确率,使用confusion_matrix函数计算混淆矩阵。
8. 计算分类指标:
```
TP = np.diag(cm)
FN = cm.sum(axis=1) - np.diag(cm)
FP = cm.sum(axis=0) - np.diag(cm)
TN = cm.sum() - (TP+FN+FP)
recall = TP / (TP+FN)
OA = (TP+TN) / (TP+FN+FP+TN)
AA = recall.sum() / len(recall)
```
这里使用numpy库计算分类指标,包括真正例TP、假负例FN、假正例FP、真负例TN、召回率recall、总准确率OA和平均准确率AA。
9. 输出结果:
```
print(OA)
print(AA)
```
最后输出总准确率OA和平均准确率AA。
阅读全文