Date: 2024-12-26 16:25:07
### Image Classification on CIFAR-10 with a VGG11 Model
#### Preparation
Start by installing and importing the required libraries. The stack is Python plus PyTorch (with torchvision); make sure both are correctly configured in your environment.
```python
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
```
#### Loading and Preprocessing CIFAR-10
Load the data through the `torchvision.datasets.CIFAR10` interface and apply suitable transforms. Beyond the basic conversion to tensors, data augmentation (random crops and horizontal flips) increases sample diversity and helps generalization[^2].
```python
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
```
#### Defining the VGG11 Model
Build a VGG11 instance from torchvision's built-in model. Because CIFAR-10 images are only 32x32, the backbone outputs a 512-dimensional feature vector rather than the 25088 features the stock classifier expects, so the final layers must be adapted; a sensible weight-initialization scheme also helps training get off to a good start[^3].
```python
class VGG11(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional backbone from torchvision's VGG11 (weights=None: train from scratch)
        self.features = models.vgg11(weights=None).features
        # On 32x32 CIFAR-10 inputs the backbone yields a 512x1x1 map, so the
        # stock classifier (which expects 25088 input features) is replaced
        # with a smaller custom head
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

model = VGG11()

# Xavier initialization for every conv and linear layer
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
```
#### Optimizer and Loss Function
Choosing a suitable optimization algorithm matters. Since a poorly chosen learning rate caused problems in earlier experiments, Adam is recommended as the default here with an initial learning rate of 0.005, decayed on a step schedule. (PyTorch's default for Adam, 1e-3, is also a reasonable starting point.)
```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
```
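`StepLR` multiplies the learning rate by `gamma` every `step_size` epochs. The schedule can be previewed with a throwaway parameter, no real model required:

```python
import torch
import torch.optim as optim

# Throwaway parameter, used only to observe the schedule
param = torch.nn.Parameter(torch.zeros(1))
opt = optim.Adam([param], lr=0.005)
sched = optim.lr_scheduler.StepLR(opt, step_size=7, gamma=0.1)

lrs = []
for epoch in range(15):
    lrs.append(opt.param_groups[0]['lr'])
    opt.step()      # optimizer.step() should precede scheduler.step()
    sched.step()

print(lrs[0], lrs[7], lrs[14])  # 0.005, then 0.0005, then 5e-05
```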
#### Training Loop
Iterate over mini-batches, running the forward pass, the backward pass, and the parameter update. Checkpointing the best-performing weights along the way lets you recover a good model even if later epochs degrade.
```python
import os

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

os.makedirs('./checkpoint', exist_ok=True)
best_acc = 0.0
epochs = 200
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()
    scheduler.step()
    acc = 100.0 * correct / total
    print(f'[Epoch {epoch + 1}] Loss: {running_loss / (i + 1):.3f}, Accuracy: {acc:.2f}%')
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), './checkpoint/vgg11_best.pth')
print('Finished Training.')
```
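The checkpointing calls used above can be exercised in isolation with a tiny stand-in model and a temporary file (a sketch; `nn.Linear(4, 2)` is an arbitrary placeholder for VGG11):

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                        # tiny placeholder for VGG11
path = os.path.join(tempfile.mkdtemp(), 'best.pth')
torch.save(model.state_dict(), path)           # same call as in the training loop

restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
same = all(torch.equal(a, b)
           for a, b in zip(model.state_dict().values(),
                           restored.state_dict().values()))
print(same)  # True: the restored weights match the saved ones
```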
#### Evaluation
After training completes, load the best checkpoint and run a final evaluation on the test set.
```python
model.load_state_dict(torch.load('./checkpoint/vgg11_best.pth'))
model.eval()  # disable dropout for evaluation
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')
```
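The accuracy bookkeeping relies on `torch.max(..., 1)` returning the argmax over the class dimension; a toy batch makes the mechanics explicit (the logits below are made up):

```python
import torch

# Made-up logits for a batch of three samples over two classes
outputs = torch.tensor([[2.0, 0.1],
                        [0.3, 1.5],
                        [0.9, 0.2]])
labels = torch.tensor([0, 1, 1])

_, predicted = torch.max(outputs, 1)        # argmax over the class dimension
correct = (predicted == labels).sum().item()
print(predicted.tolist(), correct)          # [0, 1, 0] and 2 correct
```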