knn算法python代码
时间: 2023-07-22 08:32:11 浏览: 51
下面是一个简单的KNN算法的Python实现代码:
```python
import numpy as np
from collections import Counter
def euclidean_distance(x1, x2):
return np.sqrt(np.sum((x1 - x2)**2))
class KNN:
def __init__(self, k=3):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
predicted_labels = [self._predict(x) for x in X]
return np.array(predicted_labels)
def _predict(self, x):
# Compute distances between x and all examples in the training set
distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
# Sort by distance and return indices of the first k neighbors
k_idx = np.argsort(distances)[:self.k]
# Extract the labels of the k nearest neighbor training samples
k_neighbor_labels = [self.y_train[i] for i in k_idx]
# return the most common class label
most_common = Counter(k_neighbor_labels).most_common(1)
return most_common[0][0]
```
该实现包括以下步骤:
1. 计算欧几里得距离函数
2. KNN类的初始化,包括K值和fit方法(用于训练模型)
3. 预测方法(predict),用于预测新的数据点的类别
4. 预测方法的实现,包括计算与测试集中所有点的距离,选择最近的k个点并提取它们的标签,最后返回最常见的标签值。