knn算法代码
时间: 2023-07-03 16:20:01 浏览: 97
以下是一个简单的knn算法的Python代码示例:
```python
import numpy as np
from collections import Counter
def euclidean_distance(x1, x2):
return np.sqrt(np.sum((x1 - x2)**2))
class KNN:
def __init__(self, k=3):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
predicted_labels = [self._predict(x) for x in X]
return np.array(predicted_labels)
def _predict(self, x):
distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
k_indices = np.argsort(distances)[:self.k]
k_nearest_labels = [self.y_train[i] for i in k_indices]
most_common = Counter(k_nearest_labels).most_common(1)
return most_common[0][0]
```
在这个代码中,我们定义了一个KNN类,其中包含以下方法:
- `__init__(self, k=3)`:初始化函数,其中k是KNN算法中的k值,默认为3。
- `fit(self, X, y)`:训练函数,其中X是训练数据集,y是对应的标签。
- `_predict(self, x)`:预测函数,其中x是一个样本,用于预测该样本的标签。
- `predict(self, X)`:批量预测函数,其中X是一批数据,用于批量预测这些数据对应的标签。
在这个代码中,我们使用了欧氏距离计算样本之间的距离,并使用Counter类来计算k个最近邻样本中出现最多的标签。
阅读全文