resnet18训练cifa可视化
时间: 2023-10-27 08:26:50 浏览: 153
模型训练(6) ResNet18模型可视化
在训练 ResNet-18 模型用于 CIFAR-10 数据集时,可以使用 Matplotlib 库进行训练损失和准确率的可视化。具体实现代码如下:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# 定义数据增强和加载数据集
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)
# 定义 ResNet-18 模型和损失函数
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# 训练模型
train_losses = []
train_accs = []
for epoch in range(50):
running_loss = 0.0
running_acc = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
running_acc += (predicted == labels).sum().item()
if i % 100 == 99:
train_losses.append(running_loss / 100)
train_accs.append(running_acc / 12800)
print('[%d, %5d] loss: %.3f, acc: %.3f' %
(epoch + 1, i + 1, running_loss / 100, running_acc / 12800))
running_loss = 0.0
running_acc = 0.0
# 绘制训练损失和准确率曲线
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(range(len(train_losses)), train_losses, label='Training Loss')
ax2.plot(range(len(train_accs)), train_accs, label='Training Accuracy', color='orange')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Training Loss')
ax2.set_ylabel('Training Accuracy')
ax1.set_title('Training Loss and Accuracy of ResNet-18 on CIFAR-10')
ax1.legend(loc='upper left')
ax2.legend(loc='upper right')
plt.show()
```
在这个例子中,我们使用 PyTorch 中的 ResNet-18 模型作为图像分类器,使用 CIFAR-10 数据集进行训练。我们首先定义了数据增强和加载数据集的代码,然后定义了 ResNet-18 模型、损失函数、优化器等。在训练过程中,我们记录了每个 epoch 的训练损失和准确率,并将其存储在 `train_losses` 和 `train_accs` 列表中。最后,我们使用 Matplotlib 库绘制了训练损失和准确率曲线。您可以根据自己的需求和实际情况修改代码中的参数和数据来实现自己的可视化效果。
阅读全文