c++手写数字识别knn完整
时间: 2023-08-03 07:39:44 浏览: 206
这里给出一个简单的C++ KNN数字识别的实现代码,可以对MNIST数据集进行测试:
```
#include <iostream>
#include <fstream>
#include <vector>
#include <cmath>
#include <cstring>
using namespace std;
const int k = 3; // KNN算法的k值
const int image_size = 784; // MNIST图片大小(28*28)
const int train_num = 60000; // 训练集大小
const int test_num = 10000; // 测试集大小
struct Data {
vector<int> features;
int label;
};
Data train_data[train_num];
Data test_data[test_num];
// 读取MNIST数据集
void read_mnist(string filename, Data *data, int num) {
ifstream file(filename, ios::binary);
if (!file) {
cerr << "Failed to open file: " << filename << endl;
exit(1);
}
unsigned char buffer[4];
file.read((char*)buffer, 4);
int magic_number = (buffer[0] << 24) + (buffer[1] << 16) + (buffer[2] << 8) + buffer[3];
file.read((char*)buffer, 4);
int n_images = (buffer[0] << 24) + (buffer[1] << 16) + (buffer[2] << 8) + buffer[3];
for (int i = 0; i < num; i++) {
vector<int> features(image_size);
file.read((char*)features.data(), image_size);
data[i].features = features;
}
file.close();
}
// KNN算法
int knn(Data *train_data, Data *test_data) {
int cnt[10] = {0};
int dist[train_num];
for (int i = 0; i < train_num; i++) {
int d = 0;
for (int j = 0; j < image_size; j++) {
int diff = train_data[i].features[j] - test_data->features[j];
d += diff * diff;
}
dist[i] = d;
}
for (int i = 0; i < k; i++) {
int min_dist = -1, min_index = -1;
for (int j = 0; j < train_num; j++) {
if (dist[j] != -1 && (min_dist == -1 || dist[j] < min_dist)) {
min_dist = dist[j];
min_index = j;
}
}
cnt[train_data[min_index].label]++;
dist[min_index] = -1;
}
int max_cnt = -1, max_index = -1;
for (int i = 0; i < 10; i++) {
if (cnt[i] > max_cnt) {
max_cnt = cnt[i];
max_index = i;
}
}
return max_index;
}
int main() {
string train_image_filename = "train-images.idx3-ubyte";
string train_label_filename = "train-labels.idx1-ubyte";
string test_image_filename = "t10k-images.idx3-ubyte";
string test_label_filename = "t10k-labels.idx1-ubyte";
read_mnist(train_image_filename, train_data, train_num);
read_mnist(train_label_filename, train_data, train_num);
read_mnist(test_image_filename, test_data, test_num);
read_mnist(test_label_filename, test_data, test_num);
int correct = 0;
for (int i = 0; i < test_num; i++) {
int predict_label = knn(train_data, &test_data[i]);
if (predict_label == test_data[i].label) {
correct++;
}
}
double accuracy = (double)correct / test_num;
cout << "Accuracy: " << accuracy << endl;
return 0;
}
```
这个代码利用KNN算法对MNIST数据集进行数字识别。首先通过`read_mnist`函数读取MNIST数据集,然后利用`knn`函数对测试集中的每个样本进行预测,最后计算出准确率并输出。
需要注意的是,这个代码只是一个简单的实现,没有使用任何优化算法,因此在实际使用中识别效果可能不理想。如果想要提高识别准确率,可以使用更复杂的算法或者加入更多的特征。
阅读全文