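Implement the core KNN algorithm based on the code framework below and use it to classify the Iris dataset. Requirement: sklearn's built-in KNN must not be used; any other package is allowed.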
Posted: 2024-12-16 20:28:37
Using Python, we can implement the core of the K-Nearest Neighbors (KNN) algorithm by hand. Below is a simple example that uses NumPy for the numerical work and matplotlib to visualize the results:
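```python
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
# sklearn is used only to load and split the data, not for the KNN itself
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load the Iris dataset (it could also be read from a CSV file with pandas)
iris = load_iris()
X = iris.data
y = iris.target

def euclidean_distance(point1, point2):
    """Euclidean distance between two feature vectors."""
    return np.sqrt(np.sum((point1 - point2) ** 2))

def knn_classifier(X_train, y_train, test_point, k=3):
    """Predict the class of one test point by majority vote of its k nearest neighbors."""
    # Pair the distance to every training point with that point's label
    distances = [(euclidean_distance(test_point, train_point), label)
                 for train_point, label in zip(X_train, y_train)]
    distances.sort(key=lambda x: x[0])  # sort by distance, nearest first
    nearest_k_labels = [label for _, label in distances[:k]]
    # Majority vote among the k nearest labels
    most_common_label = Counter(nearest_k_labels).most_common(1)[0][0]
    return most_common_label

def plot_data(data, labels, predicted_label):
    """Scatter the first two features: the predicted class in blue, all other classes in red."""
    pos_class = data[labels == predicted_label]
    neg_class = data[labels != predicted_label]
    plt.scatter(pos_class[:, 0], pos_class[:, 1], c='blue', marker='o', label=f'class {predicted_label}')
    plt.scatter(neg_class[:, 0], neg_class[:, 1], c='red', marker='x', label='other classes')
    plt.legend()

# Hold out a test set so predictions are not made on points the classifier has already seen
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y)

# Predict the class of the first test sample
knn_prediction = knn_classifier(X_train, y_train, X_test[0], k=3)
print(f"Predicted class of test sample 0: {knn_prediction} (true class: {y_test[0]})")

# Classify the whole test set and report accuracy
y_pred = np.array([knn_classifier(X_train, y_train, x, k=3) for x in X_test])
print(f"Test accuracy: {np.mean(y_pred == y_test):.3f}")

# Visualization (a scatter plot, not a decision boundary). Iris has 4 features,
# so only the first two (sepal length and sepal width) are plotted.
plot_data(X[:, :2], y, knn_prediction)
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.show()
```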
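The `knn_classifier` above computes distances one test point at a time in a Python loop. As a possible refinement, here is a minimal sketch (the function name `knn_predict` is our own, hypothetical choice) that classifies every test point at once with NumPy broadcasting, still without sklearn's KNN. It assumes integer class labels `0..C-1`, as in `load_iris().target`, and `k` no larger than the number of training points. One difference from the `Counter` version: on a tied vote it picks the smallest class index rather than the label met first among the nearest neighbors.

```python
import numpy as np
from sklearn.datasets import load_iris              # data loading only
from sklearn.model_selection import train_test_split

def knn_predict(X_train, y_train, X_test, k=3):
    """Hypothetical vectorized KNN: majority vote of the k nearest training points."""
    # Pairwise squared Euclidean distances, shape (n_test, n_train), via broadcasting.
    # Squared distance preserves the neighbor ordering, so the sqrt is skipped.
    diff = X_test[:, None, :] - X_train[None, :, :]
    dist2 = np.sum(diff ** 2, axis=2)
    # Column indices of the k smallest distances in each row (order within the k is arbitrary)
    nn_idx = np.argpartition(dist2, k - 1, axis=1)[:, :k]
    nn_labels = y_train[nn_idx]                      # shape (n_test, k)
    n_classes = int(y_train.max()) + 1
    # votes[i, c] = number of the k neighbors of test point i that belong to class c
    votes = (nn_labels[:, :, None] == np.arange(n_classes)).sum(axis=1)
    # Majority class; np.argmax breaks ties toward the smallest class index
    return np.argmax(votes, axis=1)

# Usage: compare a few values of k on a held-out split
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=42, stratify=iris.target)
for k in (1, 3, 5, 7):
    acc = np.mean(knn_predict(X_train, y_train, X_test, k=k) == y_test)
    print(f"k={k}: test accuracy {acc:.3f}")
```

The broadcast builds an `(n_test, n_train, n_features)` array, which is tiny for Iris but can exhaust memory on large datasets; in that case, process `X_test` in chunks.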