pytorch实现CIFAR-10分类
时间: 2024-11-03 08:01:00 浏览: 33
利用pytorch实现对CIFAR-10数据集的分类
PyTorch是一个流行的深度学习框架,可以用于实现CIFAR-10图像分类任务。CIFAR-10是一个常用的小型计算机视觉数据集,包含60,000张32x32像素的彩色图片,分为10个类别。
以下是使用PyTorch构建一个基础卷积神经网络(CNN)模型来对CIFAR-10进行分类的基本步骤:
1. **导入库**:
```python
import torch
import torchvision
import torchvision.transforms as transforms
```
2. **数据预处理**:
```python
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)
```
3. **定义模型**:
```python
model = torch.nn.Sequential(
torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2, stride=2),
# 添加更多卷积层、池化层和全连接层...
torch.nn.Linear(32 * 8 * 8, 10) # 最后一个线性层对应于10个类别的分类
)
```
4. **损失函数和优化器**:
```python
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```
5. **训练过程**:
```python
num_epochs = 10
for epoch in range(num_epochs):
for images, labels in train_loader:
optimizer.zero_grad() # 梯度清零
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 训练阶段结束后的评估
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Epoch {epoch+1}/{num_epochs} - Test Accuracy: {correct / total}")
```
阅读全文