knn算法python代码
时间: 2023-10-14 12:27:18 浏览: 29
下面是一个简单的KNN算法的Python代码示例:
```python
import numpy as np
class KNN:
def __init__(self, k=3):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
distances = self.compute_distances(X)
return self.predict_labels(distances)
def compute_distances(self, X):
num_test = X.shape[0]
num_train = self.X_train.shape[0]
distances = np.zeros((num_test, num_train))
for i in range(num_test):
for j in range(num_train):
distances[i, j] = np.sqrt(np.sum((X[i] - self.X_train[j])**2))
return distances
def predict_labels(self, distances):
num_test = distances.shape[0]
y_pred = np.zeros(num_test)
for i in range(num_test):
k_closest = self.y_train[np.argsort(distances[i])[:self.k]]
y_pred[i] = np.argmax(np.bincount(k_closest))
return y_pred
```
这个简单的KNN类实现了`fit`和`predict`方法。`fit`方法用于拟合训练数据,而`predict`方法用于预测测试数据的标签。在这个实现中,我们使用欧几里得距离作为距离的度量方式。