拥有一个鸢尾花的csv数据集,需要基于最优的k值将鸢尾花分类,并作出分类后的数据可视化图,请写出代码
时间: 2024-03-06 13:50:46 浏览: 69
好的,下面为您提供基于 KNN 算法对鸢尾花数据集进行分类并作出可视化的 Python 代码。
首先,我们需要将鸢尾花数据集导入到 Python 环境中,这里我们使用 scikit-learn 内置的鸢尾花数据集。
``` python
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
```
接下来,我们需要将数据集分成训练集和测试集,这里我们采用 80:20 的比例分割。
``` python
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
然后,我们需要确定最优的 k 值。这里我们采用交叉验证技术,使用 GridSearchCV 函数寻找最优的 k 值。
``` python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
param_grid = {'n_neighbors': range(1, 31)}
knn = KNeighborsClassifier()
grid_search = GridSearchCV(knn, param_grid, cv=5)
grid_search.fit(X_train, y_train)
print("Best k: ", grid_search.best_params_['n_neighbors'])
```
接下来,我们可以使用最优的 k 值进行模型训练和测试,并计算测试集上的准确率。
``` python
from sklearn.metrics import accuracy_score
k = grid_search.best_params_['n_neighbors']
knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Test accuracy: ", accuracy)
```
最后,我们可以使用 Matplotlib 库绘制分类后的数据可视化图。
``` python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
colors = ['r', 'g', 'b']
markers = ['o', '^', 's']
for i in range(3):
ax.scatter(X_test[y_test==i, 0], X_test[y_test==i, 1], X_test[y_test==i, 2], c=colors[i], marker=markers[i])
ax.set_xlabel('sepal length (cm)')
ax.set_ylabel('sepal width (cm)')
ax.set_zlabel('petal length (cm)')
plt.show()
```
这段代码将测试集中的数据点按照真实标签绘制在三维坐标系中,不同类别的数据点使用不同的颜色和形状表示。
希望这段代码能够对您有所帮助!
阅读全文