PyTorch cats vs. dogs code example
Time: 2023-10-26 16:08:14 Views: 107
Here is a simple PyTorch example that trains a small CNN to classify images of cats and dogs:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


# Define the neural network architecture
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # Two 2x2 poolings reduce a 224x224 input to 56x56
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # -> (N, 16, 112, 112)
        x = self.pool(torch.relu(self.conv2(x)))  # -> (N, 32, 56, 56)
        x = x.view(-1, 32 * 56 * 56)              # flatten for the linear layers
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)                           # raw logits, one per class
        return x


# Set device (GPU/CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and transform the dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.ImageFolder(root='train_data_path', transform=transform)
# On platforms that spawn worker processes (e.g. Windows), either set
# num_workers=0 or run the script body under `if __name__ == "__main__":`.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

# Define the network, loss function, and optimizer
net = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Training loop
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the average loss every 2000 mini-batches
        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}")
            running_loss = 0.0

print("Training finished!")

# Save the trained model
torch.save(net.state_dict(), 'model.pth')
```
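The `32 * 56 * 56` passed to `fc1` in the code above follows from the input resolution: each 3x3 convolution with `padding=1` keeps the spatial size, and each 2x2 max-pool halves it, so 224 becomes 112 and then 56. The sketch below reproduces that arithmetic in plain Python. The helper names `conv_out_size` and `flattened_features` are illustrative and not part of PyTorch, and the formula assumes a dilation of 1.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Output side length of a conv or pooling layer (dilation assumed to be 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(input_size=224, out_channels=32):
    """Number of features entering fc1 for the two-conv, two-pool network above."""
    size = conv_out_size(input_size, kernel=3, padding=1)  # conv1 keeps the size
    size = conv_out_size(size, kernel=2, stride=2)         # first pool halves it
    size = conv_out_size(size, kernel=3, padding=1)        # conv2 keeps the size
    size = conv_out_size(size, kernel=2, stride=2)         # second pool halves it
    return out_channels * size * size


print(flattened_features(224))  # 100352, i.e. 32 * 56 * 56
```

If you change the `Resize` target, the first argument of `fc1` and the size used in `x.view` need to change with it; for a 128x128 input, for instance, the same calculation gives `32 * 32 * 32`.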
Note:
1. Replace `'train_data_path'` with the actual path to your training dataset folder, containing separate subfolders for cats and dogs.
2. Make sure you have the necessary dependencies installed (e.g., `torch`, `torchvision`).
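`ImageFolder` treats each subfolder of the root as one class and assigns label indices in sorted order of the subfolder names. As a quick sanity check before training, the following standard-library sketch lists the class folders and counts the images in each. The function name `summarize_image_folder` and the folder names `cats` and `dogs` are assumptions for illustration, and the extension set is only a subset of what `ImageFolder` accepts.

```python
from pathlib import Path

# Assumption: a subset of the image extensions that ImageFolder accepts
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def summarize_image_folder(root):
    """Map each class subfolder to its expected label index and image count."""
    class_dirs = sorted(p for p in Path(root).iterdir() if p.is_dir())
    summary = {}
    for index, class_dir in enumerate(class_dirs):
        count = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
        summary[class_dir.name] = {"label": index, "images": count}
    return summary


# Example call (hypothetical path):
# summarize_image_folder("train_data_path")
# could return {"cats": {"label": 0, "images": ...}, "dogs": {"label": 1, "images": ...}}
```

With subfolders named `cats` and `dogs`, label 0 would correspond to cats and label 1 to dogs; after loading, the actual mapping is available as `trainset.class_to_idx`.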
This code defines a CNN (Convolutional Neural Network) architecture using PyTorch for classifying images of cats and dogs. It loads and preprocesses the dataset, trains the network, and saves the trained model. You can modify the network architecture, hyperparameters, and other aspects to fit your specific requirements.