k近邻算法python
时间: 2023-07-10 22:35:05 浏览: 40
K近邻算法是一种简单而有效的机器学习算法,通常用于分类和回归问题。它的基本思想是根据已知样本的特征,将未知样本分类到与其最近的K个已知样本中。
下面是一个简单的Python实现:
```
import numpy as np
from collections import Counter
# 定义KNN分类器
class KNNClassifier:
def __init__(self, k):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
predicted_labels = [self._predict(x) for x in X]
return np.array(predicted_labels)
def _predict(self, x):
# 计算距离
distances = [np.sqrt(np.sum((x - x_train)**2)) for x_train in self.X_train]
# 获得最近的K个样本
k_indices = np.argsort(distances)[:self.k]
# 获得最近的K个样本的标签
k_labels = [self.y_train[i] for i in k_indices]
# 投票获得预测标签
most_common = Counter(k_labels).most_common(1)
return most_common[0][0]
```
其中,fit()函数用于训练模型,predict()函数用于预测未知样本的标签,_predict()函数用于计算单个未知样本的预测标签。在_predict()函数中,首先计算未知样本与已知样本之间的距离,然后获得最近的K个样本的标签,最后通过投票的方式确定预测标签。
阅读全文