基于PyTorch的CIFAR-10图像分类
时间: 2025-01-05 11:14:02 浏览: 6
### 使用PyTorch实现CIFAR-10图像分类
#### 数据准备与预处理
为了有效地训练卷积神经网络(CNN),数据的准备工作至关重要。对于CIFAR-10这样的小型彩色图片集合,通常会应用一些基本的数据增强技术来增加样本多样性并防止过拟合。例如,在加载数据之前可以设置转换函数来进行随机裁剪、水平翻转等操作[^2]。
```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
```
#### 构建模型结构
接下来定义一个简单的CNN架构用于实验目的。这里展示了一个基础版本的CNN设计思路;当然也可以尝试更复杂的预训练模型比如ResNet系列以获得更好的性能表现。
```python
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5) # Flatten tensor
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
```
另一种方法是在最后一层使用全局平均池化代替全连接层,这可以在一定程度上减少参数量从而加快收敛速度,不过可能会牺牲部分准确性[^3]。
#### 训练过程配置
完成上述两步之后就可以着手编写具体的训练逻辑了。这部分主要包括损失函数的选择(如交叉熵)、优化器设定以及循环迭代更新权重等内容。
```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(num_epochs): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print('Finished Training')
```
阅读全文