pytorch resnet18分类器 代码
时间: 2023-11-03 11:05:14 浏览: 115
你可以使用以下代码实现PyTorch中的resnet18分类器:
```python
import torch
import torchvision.models as models
# 加载预训练的resnet18模型
model = models.resnet18(pretrained=True)
# 替换全连接层
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, num_classes)
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练和测试模型
for epoch in range(num_epochs):
# 训练模型
model.train()
for images, labels in train_loader:
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 测试模型
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print('Accuracy on test set: {:.2f}%'.format(accuracy))
```
阅读全文