PyTorch code for resuming model training from a checkpoint, with detailed comments
Posted: 2024-03-01 11:56:11
Below is a simple PyTorch example that shows how to use `torch.save()` and `torch.load()` to resume training from a checkpoint. It uses a ResNet18 model and the CIFAR10 dataset.
```python
import torch
import torchvision
import torchvision.transforms as transforms

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs = 10
batch_size = 128
learning_rate = 0.001
# Load the CIFAR10 dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# The test set must use the same normalization as the training set
# (only the data augmentation is dropped)
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
# Define the ResNet18 model (weights=None trains from scratch;
# the old pretrained=False argument is deprecated in recent torchvision)
model = torchvision.models.resnet18(weights=None)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 10)  # CIFAR10 has 10 classes
model = model.to(device)
# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

# Path where the checkpoint will be saved
checkpoint_path = './checkpoint.pth'
# Try to load a previously saved checkpoint; catch only the expected
# error instead of a bare except, which would hide real bugs
try:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print("Successfully loaded checkpoint '{}' (epoch {})".format(checkpoint_path, start_epoch))
except FileNotFoundError:
    print("No checkpoint found at '{}'".format(checkpoint_path))
    start_epoch = 0
# Training loop, starting from the restored epoch
for epoch in range(start_epoch, num_epochs):
    for i, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass and loss computation
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, len(trainloader), loss.item()))
    # Save a checkpoint after every epoch so training can resume
    # from the most recent one
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item()  # store a plain float, not the tensor and its graph
    }, checkpoint_path)
print('Finished Training')
```
In this example we first load the CIFAR10 dataset and define the ResNet18 model, the loss function, and the optimizer. We then define a checkpoint path and try to load previously saved parameters. If the load succeeds, training resumes from the saved epoch; otherwise it starts from epoch 0. During training, a checkpoint is written after every epoch, so a restart picks up from the most recently saved state. Finally, a completion message is printed.
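The same save/resume pattern can be exercised without downloading CIFAR10 or training a full ResNet18. The sketch below is illustrative (the tiny linear model, optimizer settings, and file name are stand-ins, not part of the original example); it only demonstrates that the checkpoint dictionary round-trips the model and optimizer state intact:

```python
import os
import tempfile

import torch

# Tiny stand-in model so the checkpoint round trip runs instantly
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

ckpt_path = os.path.join(tempfile.gettempdir(), 'demo_checkpoint.pth')

# Save: bundle everything needed to resume into one dictionary
torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, ckpt_path)

# Resume: rebuild identically shaped objects, then restore their state
model2 = torch.nn.Linear(4, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1)
checkpoint = torch.load(ckpt_path)
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

# The restored weights match the originals exactly
assert torch.equal(model.weight, model2.weight)
print('resume from epoch', start_epoch)
```

Note that `load_state_dict()` requires the model and optimizer to be constructed with the same architecture and parameter groups as when the checkpoint was saved; the checkpoint stores only their state, not the objects themselves.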