python简单实现knn
时间: 2023-11-21 14:57:06 浏览: 35
以下是Python实现KNN算法的简单示例:
```python
import numpy as np
from collections import Counter
# 计算欧几里得距离
def euclidean_distance(x1, x2):
return np.sqrt(np.sum((x1 - x2)**2))
class KNN:
def __init__(self, k=3):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
predicted_labels = [self._predict(x) for x in X]
return np.array(predicted_labels)
def _predict(self, x):
# 计算距离
distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
# 获取k个最近邻居的标签
k_indices = np.argsort(distances)[:self.k]
k_nearest_labels = [self.y_train[i] for i in k_indices]
# 投票,选择最常见的标签
most_common = Counter(k_nearest_labels).most_common(1)
return most_common[0][0]
```
这个示例中,我们定义了一个KNN类,其中包含了fit、predict和_predict三个方法。fit方法用于训练模型,predict方法用于预测新的数据点的标签,_predict方法用于预测单个数据点的标签。在_predict方法中,我们首先计算待预测数据点与所有训练数据点之间的距离,然后选择k个最近邻居的标签,最后使用投票的方式选择最常见的标签作为预测结果。