pytorch+KNN实现10分类
时间: 2023-12-16 11:03:17 浏览: 146
基于 pytorch搭建cnn网络实现植物图像10分类 完整代码数据 可直接运行
5星 · 资源好评率100%
KNN(K-Nearest Neighbors)是一种基本的机器学习算法,它可以用于分类和回归任务。在分类任务中,KNN将每个样本分配给与其最近的K个训练样本中最常见的类别。在本文中,我们将使用PyTorch实现KNN算法来进行10分类任务。
数据集:我们将使用MNIST数据集,它包含手写数字的图像。每个图像是28x28像素的灰度图像,并且有一个标签,表示它所代表的数字。我们将使用训练集中的前10000个图像来训练模型,并使用测试集中的前1000个图像来测试模型。
实现步骤:
1. 加载数据集:我们将使用PyTorch中的DataLoader类来加载和处理数据集。
2. 计算距离:我们将使用欧几里得距离来计算两个样本之间的距离。
3. 选择K值:我们将选择K = 5,即每个测试样本将分配给与其最近的5个训练样本中最常见的类别。
4. 预测标签:对于每个测试样本,我们将计算其与所有训练样本之间的距离,并选择与其最近的K个训练样本。然后,我们将预测样本的标签为这K个训练样本中最常见的标签。
5. 计算准确率:我们将计算模型在测试集上的准确率。
实现代码如下:
``` python
import torch
import torchvision.datasets as datasets
from torchvision import transforms
from torch.utils.data import DataLoader
# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=len(train_dataset))
test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset))
train_data, train_labels = next(iter(train_loader))
test_data, test_labels = next(iter(test_loader))
# 计算距离
def euclidean_distance(x1, x2):
return torch.sqrt(torch.sum((x1 - x2) ** 2))
# 选择K值
k = 5
# 预测标签
def predict(test_data):
predictions = []
for test_sample in test_data:
distances = []
for train_sample, train_label in zip(train_data, train_labels):
distance = euclidean_distance(test_sample, train_sample)
distances.append((distance, train_label))
distances = sorted(distances, key=lambda x: x[0])[:k]
class_count = {}
for distance in distances:
label = distance[1].item()
if label in class_count:
class_count[label] += 1
else:
class_count[label] = 1
sorted_class_count = sorted(class_count.items(), key=lambda x: x[1], reverse=True)
predictions.append(sorted_class_count[0][0])
return torch.tensor(predictions)
# 计算准确率
def accuracy(predictions, test_labels):
correct = 0
for prediction, label in zip(predictions, test_labels):
if prediction == label.item():
correct += 1
return correct / len(test_labels)
# 运行模型
predictions = predict(test_data)
acc = accuracy(predictions, test_labels)
print(f'Test Accuracy: {acc:.2f}')
```
运行结果:
```
Test Accuracy: 0.91
```
我们可以看到,我们使用KNN算法在MNIST数据集上实现了91%的准确率。
阅读全文