给我写一个使用自己数据集的图像分类测试代码
时间: 2023-04-07 09:04:01 浏览: 78
好的,我可以回答这个问题。以下是一个使用自己数据集的图像分类测试代码的示例:
```python
import torch
import torchvision
from torchvision import transforms
# 定义数据集路径和转换
data_dir = 'path/to/your/dataset'
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载数据集
dataset = torchvision.datasets.ImageFolder(root=data_dir, transform=transform)
# 定义数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
# 冻结模型参数
for param in model.parameters():
param.requires_grad = False
# 替换最后一层全连接层
num_classes = len(dataset.classes)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
for inputs, labels in dataloader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 测试模型
model.eval()
with torch.no_grad():
correct = 0
total = 0
for inputs, labels in dataloader:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
print('Accuracy: {:.2f}%'.format(accuracy * 100))
```
这个代码使用了 PyTorch 框架,首先定义了数据集路径和转换,然后加载数据集和预训练模型,接着冻结模型参数并替换最后一层全连接层,定义损失函数和优化器,最后进行训练和测试。你可以根据自己的数据集和需求进行修改。