利用scikit-learn库中datasets包中鸢尾花数据集(load_iris)load_iris数据集训练一个决策树分类器,并实现鸢尾花分类任务。 (1)数据准备; (2)模型构建与训练; (3)模型评估; (4)结果可视化。
时间: 2024-10-08 10:18:05 浏览: 24
首先,让我们逐步完成这个鸢尾花分类任务:
**步骤1:数据准备**
```python
from sklearn import datasets
import pandas as pd
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data # 特征数据
y = iris.target # 类别标签
# 将数据转换为DataFrame,方便查看
df_iris = pd.DataFrame(data=X, columns=iris.feature_names)
df_iris['target'] = y
```
**步骤2:模型构建与训练**
```python
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# 数据划分(70%用于训练,30%用于测试)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建并训练决策树分类器
clf = DecisionTreeClassifier() # 默认参数即可
clf.fit(X_train, y_train)
```
**步骤3:模型评估**
```python
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
# 预测测试集
y_pred = clf.predict(X_test)
# 评估指标
print("Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))
print("\nClassification Report:")
print(classification_report(y_test, y_pred))
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
```
**步骤4:结果可视化**
```python
import matplotlib.pyplot as plt
import seaborn as sns
# 可视化混淆矩阵
sns.heatmap(confusion_matrix(y_test, y_pred), annot=True, cmap='Blues')
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix')
plt.show()
# 决策树的可视化可以使用sklearn的plot_tree函数,但因为这通常是交互式可视化的例子,这里我们略去代码。
```