torchvision怎么使用预训练模型?
时间: 2024-05-02 13:03:49 浏览: 128
Torchvision是PyTorch的一个包,它提供了一些常用的计算机视觉模型以及常用的数据集。在Torchvision中使用预训练模型非常简单,只需要加载预训练模型的权重即可。
以下是一个使用预训练的ResNet-18模型对CIFAR-10数据集进行分类的示例代码:
```python
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# 加载预训练的ResNet-18模型
model = models.resnet18(pretrained=True)
# 将最后一层全连接层替换为10分类的全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
# 定义数据预处理步骤
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练模型
for epoch in range(10):
for i, (images, labels) in enumerate(train_dataset):
# 前向传播
outputs = model(images.unsqueeze(0))
loss = criterion(outputs, labels.unsqueeze(0))
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 在测试集上测试模型的准确率
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_dataset:
outputs = model(images.unsqueeze(0))
_, predicted = torch.max(outputs.data, 1)
total += 1
correct += (predicted == labels).sum().item()
print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))
```
在上述代码中,我们首先加载预训练的ResNet-18模型,并将其最后一层全连接层替换为10分类的全连接层。然后,我们定义了数据预处理步骤并加载CIFAR-10数据集。接着,我们定义了损失函数和优化器,并进行了模型的训练。最后,我们在测试集上测试了模型的准确率。
阅读全文
相关推荐
















