resnet50代码 pytorch
时间: 2023-07-28 22:11:08 浏览: 108
resnet pytorch代码
当然,下面是一个使用PyTorch实现的ResNet-50模型的代码示例:
```python
import torch
import torch.nn as nn
import torchvision.models as models
# 加载预训练的ResNet-50模型
resnet = models.resnet50(pretrained=True)
# 替换最后一层全连接层
num_features = resnet.fc.in_features
resnet.fc = nn.Linear(num_features, num_classes)
# 将模型转移到GPU上(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet = resnet.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.001, momentum=0.9)
# 训练模型
for epoch in range(num_epochs):
for images, labels in dataloader:
images = images.to(device)
labels = labels.to(device)
# 前向传播
outputs = resnet(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印训练信息
if (i+1) % 100 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], Loss: {loss.item():.4f}')
```
上面的代码使用了PyTorch中的`torchvision.models`模块来加载预训练的ResNet-50模型,并替换了最后一层全连接层以适应特定的分类任务。然后,将模型移动到GPU上(如果可用),定义损失函数和优化器,并通过迭代训练数据进行模型训练。
请注意,这只是一个基本的代码示例,实际使用时可能需要根据具体任务进行调整,并根据数据集准备数据加载器和其他必要的步骤。
阅读全文