resnet18多分类输出代码
时间: 2024-03-04 16:46:41 浏览: 129
多分类代码
当使用ResNet-18进行多分类任务时,我们需要对网络的最后一层进行修改,以适应分类的类别数量。以下是一个示例代码,展示了如何使用PyTorch实现ResNet-18的多分类输出:
```python
import torch
import torch.nn as nn
import torchvision.models as models
# 加载预训练的ResNet-18模型
resnet = models.resnet18(pretrained=True)
# 冻结所有参数
for param in resnet.parameters():
param.requires_grad = False
# 获取ResNet-18的最后一层输入特征数量
num_features = resnet.fc.in_features
# 替换最后一层全连接层
resnet.fc = nn.Linear(num_features, num_classes)
# 将模型设置为训练模式
resnet.train()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.001, momentum=0.9)
# 训练过程
for epoch in range(num_epochs):
# 前向传播
outputs = resnet(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印训练信息
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
# 测试过程
resnet.eval()
with torch.no_grad():
outputs = resnet(test_inputs)
_, predicted = torch.max(outputs.data, 1)
accuracy = (predicted == test_labels).sum().item() / len(test_labels)
print(f'Test Accuracy: {accuracy}')
```
在上述代码中,我们首先加载预训练的ResNet-18模型,并冻结所有参数。然后,我们通过替换最后一层全连接层来适应多分类任务的类别数量。接下来,我们定义损失函数和优化器,并进行训练和测试过程。
阅读全文