knn算法python代码
时间: 2023-07-22 12:04:36 浏览: 82
下面是一个简单的KNN算法的Python代码实现,其中包含了KNN算法的训练和预测过程:
```
import numpy as np
class KNN:
def __init__(self, k=3):
self.k = k
def fit(self, X_train, y_train):
self.X_train = X_train
self.y_train = y_train
def predict(self, X_test):
predictions = []
for i in range(len(X_test)):
distances = np.sum((self.X_train - X_test[i])**2, axis=1)
nearest_neighbors = np.argsort(distances)[:self.k]
neighbor_labels = self.y_train[nearest_neighbors]
counts = np.bincount(neighbor_labels)
predictions.append(np.argmax(counts))
return predictions
```
其中,fit()方法用于训练模型,需要传入训练数据X_train和标签数据y_train;predict()方法用于进行预测,需要传入测试数据X_test。在预测过程中,对于每个测试样本,首先计算其与所有训练样本的距离,选取距离最近的k个样本,统计它们的标签信息,并返回出现次数最多的标签作为预测结果。
阅读全文