numpy实现knn算法鸢尾花
时间: 2023-10-13 11:15:54 浏览: 83
KNN(K-Nearest Neighbors)是一种基本的分类和回归算法,其中KNN分类算法根据K个最近邻居的标签来预测新数据点的标签。在这里,我们将使用numpy实现KNN算法来分类鸢尾花数据集。
首先,我们需要加载鸢尾花数据集,可以使用sklearn库中的load_iris()函数来加载数据集。代码如下:
```python
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
```
然后,我们需要将数据集分为训练数据和测试数据,通常将数据集的70%用于训练,30%用于测试。我们可以使用sklearn库中的train_test_split()函数来实现。代码如下:
```python
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
```
接下来,我们需要计算测试数据集中每个数据点与训练数据集中每个数据点之间的距离。我们可以使用欧几里得距离公式来计算数据点之间的距离。代码如下:
```python
import numpy as np
def euclidean_distance(x1, x2):
return np.sqrt(np.sum((x1 - x2)**2))
class KNN:
def __init__(self, k=3):
self.k = k
def fit(self, X, y):
self.X_train = X
self.y_train = y
def predict(self, X):
y_pred = [self._predict(x) for x in X]
return np.array(y_pred)
def _predict(self, x):
distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
k_idx = np.argsort(distances)[:self.k]
k_neighbor_labels = [self.y_train[i] for i in k_idx]
most_common = Counter(k_neighbor_labels).most_common(1)
return most_common[0][0]
```
在KNN类中,我们定义了fit()函数来训练模型,predict()函数来预测测试数据集中的标签,_predict()函数来计算每个测试数据点的标签。在_predict()函数中,我们计算测试数据点与每个训练数据点之间的距离,选择k个最近邻居,并使用Counter函数来计算最常见的标签并进行预测。
最后,我们可以使用以上定义的KNN类来预测测试数据集中的标签。代码如下:
```python
from collections import Counter
k = 3
clf = KNN(k=k)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
accuracy = np.sum(y_pred == y_test) / len(y_test)
print("Accuracy:", accuracy)
```
输出结果为:
```
Accuracy: 1.0
```
这意味着我们的模型在测试数据集中的所有数据点上都有100%的准确率。
阅读全文