knn算法python代码
时间: 2023-10-12 17:20:34 浏览: 26
以下是一个简单的KNN算法的Python代码实现:
```python
import numpy as np
class KNN:
def __init__(self, k):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
distances = []
targets = []
for i in range(len(self.X_train)):
distance = np.sqrt(np.sum(np.square(X - self.X_train[i, :])))
distances.append([distance, i])
distances = sorted(distances)
for i in range(self.k):
index = distances[i][1]
targets.append(self.y_train[index])
return Counter(targets).most_common(1)[0][0]
```
这个KNN类包含三个方法:
- `__init__`:初始化KNN算法的K值;
- `fit`:将训练数据集X和y存储在类中,以备后续的预测使用;
- `predict`:根据输入的测试数据集X,预测其对应的标签。
其中,`predict`方法包含了KNN算法的核心实现:
- 对于每个测试数据点,计算它与所有训练数据点的距离;
- 将距离按照从小到大排序;
- 取前K个距离最小的训练数据点的标签,并统计它们出现的次数;
- 返回出现次数最多的标签作为预测结果。
需要注意的是,这个代码实现中使用了Python的NumPy库来进行向量和矩阵的运算,以提高计算效率。