用C++写KNN算法处理鸢尾花数据集
时间: 2023-10-24 19:16:12 浏览: 103
KNN算法实现鸢尾花数据集分类 C语言实现.zip
好的,下面是使用C语言实现KNN算法处理鸢尾花数据集的示例代码:
```c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#define K 3
#define MAX_SAMPLES 150
#define MAX_FEATURES 4
typedef struct {
double features[MAX_FEATURES];
int label;
} Sample;
Sample samples[MAX_SAMPLES];
double distance(Sample s1, Sample s2) {
double dist = 0.0;
for (int i = 0; i < MAX_FEATURES; i++) {
dist += pow(s1.features[i] - s2.features[i], 2);
}
return sqrt(dist);
}
int compare(const void *a, const void *b) {
Sample *s1 = (Sample *)a;
Sample *s2 = (Sample *)b;
return (s1->label - s2->label);
}
int classify(Sample test) {
Sample nearest[K];
for (int i = 0; i < K; i++) {
nearest[i] = samples[i];
}
for (int i = K; i < MAX_SAMPLES; i++) {
double dist = distance(samples[i], test);
int max_idx = 0;
for (int j = 1; j < K; j++) {
if (distance(nearest[j], test) > distance(nearest[max_idx], test)) {
max_idx = j;
}
}
if (dist < distance(nearest[max_idx], test)) {
nearest[max_idx] = samples[i];
}
}
qsort(nearest, K, sizeof(Sample), compare);
int count[3] = {0};
for (int i = 0; i < K; i++) {
count[nearest[i].label]++;
}
int max_count = 0;
int max_label = -1;
for (int i = 0; i < 3; i++) {
if (count[i] > max_count) {
max_count = count[i];
max_label = i;
}
}
return max_label;
}
int main() {
FILE *fp = fopen("iris.data", "r");
if (fp == NULL) {
printf("Failed to open file\n");
return 1;
}
int idx = 0;
while (!feof(fp)) {
fscanf(fp, "%lf,%lf,%lf,%lf,%d\n", &samples[idx].features[0],
&samples[idx].features[1], &samples[idx].features[2], &samples[idx].features[3],
&samples[idx].label);
idx++;
}
fclose(fp);
Sample test;
test.features[0] = 5.1;
test.features[1] = 3.5;
test.features[2] = 1.4;
test.features[3] = 0.2;
int predicted_label = classify(test);
printf("Predicted label: %d\n", predicted_label);
return 0;
}
```
上述代码中,我们首先定义了一个 `Sample` 结构体来存储每个样本的特征和标签。我们还定义了一个 `distance` 函数来计算两个样本之间的距离。
在 `classify` 函数中,我们首先使用前 K 个样本作为最近邻。然后对于剩余的样本,我们计算它与测试样本之间的距离,并将其与当前 K 个最近邻中最远的样本进行比较。如果它与当前最远的样本更接近,则将其替换为最近邻之一。最后,我们使用最近邻的标签来预测测试样本的标签。
在 `main` 函数中,我们首先从文件中读取所有样本。然后我们手动创建一个测试样本,并使用 `classify` 函数来预测它的标签。
注意,上述代码中的 `iris.data` 文件应该包含鸢尾花数据集。每行应该包含四个特征和一个标签,用逗号分隔。在这个示例中,我们使用了一个硬编码的测试样本,但实际上我们可以从文件中读取测试样本并进行分类。
阅读全文