knn算法实现鸢尾花分类
时间: 2023-06-03 22:03:26 浏览: 315
K-NN(K-Nearest Neighbors,K 近邻)算法是一种常见的机器学习算法,可用于分类和回归。对于鸢尾花分类问题,K-NN 算法可以根据花萼长度、花萼宽度、花瓣长度和花瓣宽度这四个特征,将鸢尾花分为 Setosa、Versicolor 和 Virginica 三个种类。对每一个待分类的测试样本,K-NN 算法先计算它与训练集中所有样本的距离,再选出距离最近的 K 个训练样本并统计它们的类别,最后以其中出现次数最多的类别(众数)作为该测试样本的分类标签。
相关问题
KNN算法实现鸢尾花分类
KNN(K-Nearest Neighbor)算法是一种基本的分类和回归算法,其核心思想是通过找到与待分类样本最近的K个训练集样本,来确定待分类样本所属的类别。在这里,我们将使用KNN算法来实现鸢尾花的分类。
首先,我们需要导入必要的库和数据集:
```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# Load the iris dataset bundled with scikit-learn.
iris = load_iris()
# X is the feature matrix, y the target labels.
X, y = iris.data, iris.target
```
然后,我们将数据集分为训练集和测试集:
```python
# Hold out 30% of the samples for testing; the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)
```
接下来,我们需要选择K值,并训练KNN模型:
```python
# Number of neighbors consulted for each prediction.
k = 3
# fit() returns the estimator itself, so construction and training can be chained.
knn = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)
```
最后,我们使用测试集来评估模型的性能:
```python
# Predict a label for every held-out sample.
y_pred = knn.predict(X_test)
# Compare the predictions with the true labels.
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print("Accuracy:", accuracy)
```
完整的代码如下:
```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# Load the iris dataset: X is the feature matrix, y the target labels.
iris = load_iris()
X, y = iris.data, iris.target

# Hold out 30% of the samples for testing; the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

# Train a K-NN classifier that consults k neighbors per prediction.
k = 3
knn = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train)

# Predict the held-out samples and report the accuracy of the predictions.
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print("Accuracy:", accuracy)
```
运行结果将会输出模型在测试集上的准确率。
knn算法实现鸢尾花分类C语言
以下是基于C语言实现的KNN算法鸢尾花分类代码:
```c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#define K 3 // number of nearest neighbors that take part in the majority vote
/* One iris sample: four numeric measurements plus a textual class label. */
typedef struct {
    /* sepal length */
    double sepallength;
    /* sepal width */
    double sepalwidth;
    /* petal length */
    double petallength;
    /* petal width */
    double petalwidth;
    /* class label as a NUL-terminated string (at most 19 characters) */
    char class[20];
} Iris;
/*
 * Load iris records from a comma-separated text file.
 *
 * Each accepted line has the form
 *     <sepal length>,<sepal width>,<petal length>,<petal width>,<class label>
 * Lines that do not yield all five fields (for example a blank line) are
 * skipped instead of being counted as records.
 *
 * filename: path of the file to read.
 * dataset:  output array. There is no capacity argument, so the caller must
 *           make it large enough for every record in the file.
 * n:        receives the number of records stored in dataset.
 *
 * Prints a message and terminates the process with exit(1) if the file
 * cannot be opened.
 */
void read_data(char* filename, Iris* dataset, int* n) {
    FILE* fp = fopen(filename, "r");
    if (fp == NULL) {
        printf("Open file %s failed!\n", filename);
        exit(1);
    }
    char buf[1024];
    int count = 0;
    while (fgets(buf, sizeof buf, fp) != NULL) {
        /* Parse into a scratch record so a malformed line never leaves a
           half-filled slot in the output array. */
        Iris rec;
        /* %19s bounds the label to the 20-byte class field (19 chars + NUL). */
        int fields = sscanf(buf, "%lf,%lf,%lf,%lf,%19s", &rec.sepallength, &rec.sepalwidth,
                            &rec.petallength, &rec.petalwidth, rec.class);
        if (fields != 5) {
            continue;
        }
        dataset[count] = rec;
        count++;
    }
    *n = count;
    fclose(fp);
}
/*
 * Euclidean distance between two samples, computed over the four numeric
 * measurements. The class label is not used.
 */
double distance(Iris* p, Iris* q) {
    double d_sepal_len = p->sepallength - q->sepallength;
    double d_sepal_wid = p->sepalwidth - q->sepalwidth;
    double d_petal_len = p->petallength - q->petallength;
    double d_petal_wid = p->petalwidth - q->petalwidth;

    double sum_sq = pow(d_sepal_len, 2) + pow(d_sepal_wid, 2) +
                    pow(d_petal_len, 2) + pow(d_petal_wid, 2);
    return sqrt(sum_sq);
}
/*
 * Select the K samples of dataset that are closest to test.
 *
 * dataset:   candidate samples.
 * n:         number of entries in dataset. The first K entries are read
 *            unconditionally, so n is expected to be at least K.
 * test:      query sample.
 * neighbors: output array of K pointers into dataset, in no particular order.
 */
void find_k_neighbors(Iris* dataset, int n, Iris* test, Iris** neighbors) {
    double worst_dist = 0.0; /* largest distance among the current neighbors */
    int worst_slot = 0;      /* slot of neighbors[] holding that farthest sample */

    /* Seed the neighbor set with the first K samples, tracking the farthest. */
    for (int slot = 0; slot < K; slot++) {
        neighbors[slot] = dataset + slot;
        double seed_dist = distance(neighbors[slot], test);
        if (seed_dist > worst_dist) {
            worst_dist = seed_dist;
            worst_slot = slot;
        }
    }

    /* Every remaining sample evicts the current farthest neighbor if it is closer. */
    for (int idx = K; idx < n; idx++) {
        Iris* candidate = dataset + idx;
        double cand_dist = distance(candidate, test);
        if (!(cand_dist < worst_dist)) {
            continue;
        }
        neighbors[worst_slot] = candidate;
        worst_dist = cand_dist;
        /* Re-scan to find which slot is now the farthest; if none exceeds the
           newcomer, the newcomer's own slot stays the eviction target. */
        for (int slot = 0; slot < K; slot++) {
            double slot_dist = distance(neighbors[slot], test);
            if (slot_dist > worst_dist) {
                worst_dist = slot_dist;
                worst_slot = slot;
            }
        }
    }
}
/*
 * Majority vote over the K neighbors.
 *
 * Counts how many neighbors carry each class label and returns the label
 * with the highest count; on a tie the label listed first wins. A neighbor
 * whose label is neither "Iris-setosa" nor "Iris-versicolor" is counted as
 * "Iris-virginica".
 *
 * Returns a pointer to a string constant; the caller must not modify or
 * free it.
 */
char* get_class(Iris** neighbors) {
    static char* const labels[3] = {"Iris-setosa", "Iris-versicolor", "Iris-virginica"};
    int votes[3] = {0, 0, 0};

    for (int i = 0; i < K; i++) {
        int bucket = 2; /* default bucket for anything not matching the first two labels */
        for (int c = 0; c < 2; c++) {
            if (strcmp(neighbors[i]->class, labels[c]) == 0) {
                bucket = c;
                break;
            }
        }
        votes[bucket]++;
    }

    /* Strict '>' keeps the earliest label when counts are equal. */
    int winner = 0;
    for (int c = 1; c < 3; c++) {
        if (votes[c] > votes[winner]) {
            winner = c;
        }
    }
    return labels[winner];
}
/*
 * Program entry point: loads the samples from "iris.data", reads the four
 * measurements of one sample from standard input, and prints the class
 * chosen by a vote among its K nearest neighbors.
 *
 * Returns 0 on success, or EXIT_FAILURE when fewer than K samples were
 * loaded or the four measurements could not be read.
 */
int main(void) {
    enum { DATASET_CAPACITY = 150 };
    Iris dataset[DATASET_CAPACITY]; /* training samples */
    Iris test_data = {0};           /* sample to classify */
    Iris* neighbors[K];             /* nearest neighbors of test_data */
    int n = 0;                      /* number of samples loaded */

    read_data("iris.data", dataset, &n);
    /* read_data takes no capacity argument, so clamp the count it reports:
       nothing below may index past the end of dataset. */
    if (n > DATASET_CAPACITY) {
        n = DATASET_CAPACITY;
    }
    /* find_k_neighbors reads the first K samples unconditionally. */
    if (n < K) {
        fprintf(stderr, "Need at least %d samples, but only %d were loaded.\n", K, n);
        return EXIT_FAILURE;
    }

    /* Read the sample to classify from standard input. */
    printf("Please input the test data (sepal length, sepal width, petal length, petal width):\n");
    if (scanf("%lf%lf%lf%lf", &test_data.sepallength, &test_data.sepalwidth,
              &test_data.petallength, &test_data.petalwidth) != 4) {
        fprintf(stderr, "Invalid input: expected four numbers.\n");
        return EXIT_FAILURE;
    }

    find_k_neighbors(dataset, n, &test_data, neighbors);
    printf("The class of the test data is %s.\n", get_class(neighbors));
    return 0;
}
```
注意:代码以相对路径打开数据集文件"iris.data",该文件需要自行下载,并放在程序运行时的当前目录下;文件每行的格式为"花萼长度,花萼宽度,花瓣长度,花瓣宽度,类别"(以逗号分隔)。
阅读全文