PyTorch cats vs. dogs code example
Time: 2023-10-26 13:08:14
Sure! Here's a simple code example using PyTorch to classify images of cats and dogs:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


# Define the neural network architecture
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Two conv blocks: 3 -> 16 -> 32 channels; padding=1 keeps the spatial size
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # halves height and width
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # 224x224 input -> 112x112 after the first pool -> 56x56 after the second
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 2)  # two output logits: cat and dog

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten everything except the batch dimension
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)  # raw logits; CrossEntropyLoss applies log-softmax itself
        return x


def main():
    # Set device (GPU/CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load and transform the dataset; the 224x224 size is what fc1's input size assumes
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainset = torchvision.datasets.ImageFolder(root='train_data_path', transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

    # Define the network, loss function, and optimizer
    net = Net().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # Training loop
    net.train()
    for epoch in range(5):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Print the average loss over every 2000 mini-batches
            running_loss += loss.item()
            if i % 2000 == 1999:
                print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0

    print("Training finished!")

    # Save the trained weights (the state_dict, not the whole model object)
    torch.save(net.state_dict(), 'model.pth')


# Required when num_workers > 0 on platforms that start DataLoader workers with
# 'spawn' (Windows, and macOS by default): each worker re-imports this file, and
# without the guard it would try to start training again and fail.
if __name__ == '__main__':
    main()
```
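The `32 * 56 * 56` input size of `fc1` comes from the usual output-size rule for convolution and pooling layers (dilation 1): `floor((n + 2 * padding - kernel) / stride) + 1`. The plain-Python sketch below applies that rule to the layers in `Net`, so it runs without PyTorch. The helper names `out_size` and `fc1_in_features` are hypothetical and only for illustration. If you change `Resize((224, 224))`, this is one way to recompute the matching `fc1` size.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


def fc1_in_features(image_size=224, channels=32):
    """Number of features entering fc1 for a square input of the given size."""
    n = image_size
    n = out_size(n, kernel=3, stride=1, padding=1)  # conv1: size unchanged
    n = out_size(n, kernel=2, stride=2)             # pool: halved
    n = out_size(n, kernel=3, stride=1, padding=1)  # conv2: size unchanged
    n = out_size(n, kernel=2, stride=2)             # pool: halved again
    return channels * n * n


print(fc1_in_features(224))  # 32 * 56 * 56 = 100352
print(fc1_in_features(128))  # 32 * 32 * 32 = 32768
```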
Note:
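1. Replace `'train_data_path'` with the actual path to your training dataset folder. `ImageFolder` expects one subfolder per class (for example `train_data_path/cat/` and `train_data_path/dog/`) and numbers the classes in sorted folder-name order, so with those names `cat` becomes label 0 and `dog` becomes label 1.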
2. Make sure you have the necessary dependencies installed (e.g., `torch`, `torchvision`).
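`ImageFolder` treats each subfolder of `root` as one class and assigns integer labels in sorted order of the folder names. The stdlib-only sketch below mimics that rule so you can see which label your cat and dog folders will receive. `build_class_to_idx` is a hypothetical helper, not torchvision's own code; on the real dataset the same mapping is available as `trainset.class_to_idx`.

```python
import os
import tempfile


def build_class_to_idx(root):
    """Map each subfolder of `root` to an integer label, in sorted name order."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


# Build a throwaway layout like train_data_path/dogs/ and train_data_path/cats/
with tempfile.TemporaryDirectory() as root:
    for name in ("dogs", "cats"):  # creation order does not affect the labels
        os.makedirs(os.path.join(root, name))
    mapping = build_class_to_idx(root)

print(mapping)  # {'cats': 0, 'dogs': 1}
```

With your real data, `print(trainset.class_to_idx)` should show the mapping for your own folder names.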
This code defines a small convolutional neural network (CNN) in PyTorch for classifying images of cats and dogs. It loads and preprocesses the dataset with `ImageFolder`, trains the network for 5 epochs using SGD (learning rate 0.001, momentum 0.9), and saves the trained weights to `model.pth`. You can adjust the architecture, the hyperparameters (batch size, learning rate, number of epochs), and other settings to fit your requirements.
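`fc2` outputs two raw scores (logits) per image, and `nn.CrossEntropyLoss` applies log-softmax internally, which is why `forward` does not end with a softmax. The plain-Python sketch below shows the arithmetic for a single image: the loss is `logsumexp(logits) - logits[label]`, and the predicted class is the index of the largest logit. It assumes class folders named `cats` and `dogs`, and the logit values are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Per-sample loss that nn.CrossEntropyLoss computes from raw logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


classes = ["cats", "dogs"]  # assumed ImageFolder order (sorted folder names)
logits = [0.5, 2.0]         # made-up output of the network for one image
probs = softmax(logits)
pred = max(range(len(logits)), key=lambda i: logits[i])

print(classes[pred], [round(p, 3) for p in probs])
print(cross_entropy(logits, label=1))  # loss if the true label is "dogs"
```

The loss equals `-log(probs[label])`, so it shrinks toward 0 as the probability of the correct class approaches 1.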