请帮我写一个基于ResNet18实现猫狗分类的程序
时间: 2024-05-05 11:21:06 浏览: 126
好的，以下是一个基于 ResNet18 实现猫狗分类的程序，使用 PyTorch 框架：
```python
import multiprocessing

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Data preprocessing.
# The ResNet18 backbone is initialised with ImageNet-pretrained weights, which
# were trained on inputs normalised with ImageNet's per-channel mean/std.
# Using (0.5, 0.5, 0.5) here would feed the pretrained layers a shifted input
# distribution.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training: random crop + horizontal flip for augmentation.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# Testing: deterministic resize + center crop to the same 224x224 input size.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# Load the datasets. torchvision's ImageFolder expects one sub-directory per
# class under each root (e.g. ./train/cat and ./train/dog).
#
# DataLoader worker processes started with the "spawn" or "forkserver" start
# methods (spawn is the default on Windows and macOS) re-import this script.
# Because the script runs everything at top level without an
# `if __name__ == "__main__":` guard, that re-import would start training again
# inside the worker and abort. Only use worker processes under "fork".
num_workers = 2 if multiprocessing.get_start_method() == 'fork' else 0
trainset = torchvision.datasets.ImageFolder(root='./train', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=num_workers)
testset = torchvision.datasets.ImageFolder(root='./test', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=num_workers)
# Define the ResNet18 model.
class ResNet18(nn.Module):
    """Cat/dog classifier: a torchvision ResNet-18 with a new output layer.

    Args:
        num_classes: size of the new classification head (2 for cat vs. dog).
        pretrained: if True (the default), initialise the backbone with
            torchvision's ImageNet weights; otherwise use random weights.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        weights_enum = getattr(torchvision.models, 'ResNet18_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
            weights = weights_enum.DEFAULT if pretrained else None
            self.resnet = torchvision.models.resnet18(weights=weights)
        else:
            # Older torchvision releases only understand the legacy boolean flag.
            self.resnet = torchvision.models.resnet18(pretrained=pretrained)
        # Replace the ImageNet classification head with a `num_classes`-way one.
        # Reading in_features from the existing layer avoids hard-coding 512.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return unnormalised class scores (logits) for the image batch `x`."""
        return self.resnet(x)
# Set up the compute device, the model, the loss function and the optimizer.
# Prefer the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

net = ResNet18()
net = net.to(device)

# Cross-entropy over the class logits, optimised with SGD + momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)
# Train the model.
NUM_EPOCHS = 10
LOG_INTERVAL = 100  # print the mean loss every LOG_INTERVAL mini-batches

for epoch in range(NUM_EPOCHS):
    # Put the model in training mode explicitly (ResNet's BatchNorm layers
    # behave differently in train and eval mode), so this loop stays correct
    # even if the model was switched to eval mode beforehand.
    net.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # The divisor must match the logging interval, so both use LOG_INTERVAL.
        if (i + 1) % LOG_INTERVAL == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOG_INTERVAL:.3f}')
            running_loss = 0.0
print('Finished Training')
# Evaluate the model on the test set.
# eval() is required: in training mode BatchNorm normalises with per-batch
# statistics and keeps updating its running averages, which skews test results
# and silently modifies the trained model.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        # The predicted class is the index of the largest logit in each row.
        predicted = net(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
# Two decimals instead of %d, which truncated e.g. 97.9 to 97.
print(f'Accuracy of the network on the test images: {accuracy:.2f} %')
```
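请注意，此程序假定您已经下载了猫狗数据集，并按类别放在 train 和 test 文件夹下的子文件夹中，例如 ./train/cat、./train/dog、./test/cat、./test/dog。ImageFolder 会根据子文件夹名确定类别标签。另外，此程序使用在 ImageNet 上预训练的 ResNet18 模型。首次运行时 torchvision 会自动下载预训练权重，因此需要联网。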
阅读全文