resent训练可视化
时间: 2023-10-30 10:08:13 浏览: 101
resnet50预训练权重文件和cfg文件
ResNet 是一种非常流行的深度卷积神经网络,可以用于图像分类、物体检测等任务。ResNet 的训练可视化可以帮助我们更好地理解模型的优化过程,以及如何进一步改进模型的性能。下面是一些使用 PyTorch 和 TensorBoard 可视化 ResNet 训练的示例代码:
```python
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
# 定义数据预处理
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载 CIFAR10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
# 加载 ResNet18 模型
net = torchvision.models.resnet18(pretrained=False, num_classes=10)
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# 创建 TensorBoard SummaryWriter 对象
writer = SummaryWriter()
# 训练循环
for epoch in range(200):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# 获取输入数据和标签
inputs, labels = data
# 将输入数据和标签转换为 GPU 上的 tensor
inputs, labels = inputs.cuda(), labels.cuda()
# 清除梯度
optimizer.zero_grad()
# 前向传播
outputs = net(inputs)
# 计算损失函数
loss = criterion(outputs, labels)
# 反向传播
loss.backward()
# 更新模型参数
optimizer.step()
# 统计损失函数值
running_loss += loss.item()
if i % 100 == 99: # 每 100 个 batch 输出一次损失函数值
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
# 将损失函数值写入 TensorBoard
writer.add_scalar('training_loss', running_loss / 100, epoch * len(trainloader) + i)
running_loss = 0.0
print('Finished Training')
# 关闭 TensorBoard SummaryWriter 对象
writer.close()
```
在上面的代码中,我们首先定义了数据预处理、加载 CIFAR10 数据集和 ResNet18 模型。接着,我们定义了损失函数和优化器,并创建了一个 TensorBoard SummaryWriter 对象。在训练循环中,我们依次读取每个 batch 的数据,将输入数据和标签转换为 GPU 上的 tensor,然后进行前向传播、计算损失函数、反向传播和更新模型参数等操作。在每 100 个 batch 计算一次平均损失函数值并输出到控制台,同时将损失函数值写入 TensorBoard 以便后续可视化分析。当训练结束后,我们关闭 TensorBoard SummaryWriter 对象。
阅读全文