resnet50训练部分代码
时间: 2024-04-11 19:25:27 浏览: 142
ResNet-50是一种深度卷积神经网络模型,常用于图像分类任务。下面是一个简单的ResNet-50训练部分的代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
# 加载预训练的ResNet-50模型
model = models.resnet50(pretrained=True)
# 替换最后一层全连接层,适应分类任务的输出类别数
num_classes = 10 # 假设有10个类别
model.fc = nn.Linear(model.fc.in_features, num_classes)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 加载训练数据和标签
train_loader = ...
labels = ...
# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
running_loss = 0.0
for images, labels in train_loader:
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
# 打印每个epoch的损失值
print(f"Epoch {epoch+1} loss: {running_loss / len(train_loader)}")
# 保存训练好的模型
torch.save(model.state_dict(), "resnet50_model.pth")
```
这段代码首先加载了预训练的ResNet-50模型,并替换了最后一层全连接层以适应具体的分类任务。然后定义了损失函数和优化器。接下来,通过迭代训练数据和标签进行模型训练,使用交叉熵损失函数计算损失,并使用随机梯度下降(SGD)优化器更新模型参数。最后,保存训练好的模型参数。
阅读全文