猫狗识别pytorch代码
时间: 2024-12-28 21:36:14 浏览: 5
在PyTorch中进行猫狗识别通常会使用深度学习框架,如卷积神经网络(Convolutional Neural Networks, CNNs)。以下是一个简单的示例,展示如何使用ResNet等预训练模型进行图像分类任务:
```python
import torch
from torchvision import datasets, models, transforms
# 数据预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.ImageFolder('path_to_dogs_vs_cats/train', transform=transform)
test_dataset = datasets.ImageFolder('path_to_dogs_vs_cats/test', transform=transform)
# 数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_loader, batch_size=32, shuffle=False)
# 使用预训练的ResNet模型
model = models.resnet50(pretrained=True)
num_classes = len(train_dataset.classes) # 确定类别数,这里是猫和狗
model.fc = torch.nn.Linear(model.fc.in_features, num_classes) # 修改最后一层全连接层
# 模型训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练过程...
# ...
```
请注意,这只是一个基础版本的代码,实际应用中可能还需要数据增强、验证集验证、训练循环以及模型评估等步骤。
阅读全文