pytorch图像分类实战
时间: 2023-10-20 15:06:50 浏览: 105
1. 准备数据集
在图像分类实战中,我们通常需要先准备数据集。PyTorch中提供了torchvision模块,可以方便地加载和处理常见的图像数据集,如MNIST、CIFAR10、ImageNet等。
下面以CIFAR10数据集为例,介绍如何准备数据集。
1.1 下载数据集
可以从CIFAR官网下载CIFAR10数据集:https://www.cs.toronto.edu/~kriz/cifar.html
也可以直接使用PyTorch提供的数据集下载函数:
```
import torchvision.datasets as datasets
trainset = datasets.CIFAR10(root='./data', train=True, download=True)
testset = datasets.CIFAR10(root='./data', train=False, download=True)
```
其中,train=True表示下载训练集,train=False表示下载测试集。root参数指定数据集下载的目录。
1.2 加载数据集
PyTorch提供了DataLoader类,可以方便地对数据集进行批量处理和加载。
```
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
```
其中,transform参数指定数据预处理的方式。上面的代码将图像resize到224x224大小,并进行了标准化处理。batch_size参数指定每个batch的大小,shuffle参数指定是否随机打乱数据集,num_workers参数指定数据加载时使用的线程数。
2. 定义模型
接下来,我们需要定义一个模型来对图像进行分类。在PyTorch中,可以使用nn模块来定义模型。
```
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
```
上面的代码定义了一个简单的卷积神经网络,包含两个卷积层、两个池化层和三个全连接层。
3. 训练模型
有了数据集和模型,接下来就可以开始训练模型了。在PyTorch中,可以使用optim模块来定义优化器,使用nn模块中的损失函数来计算损失。
```
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
```
上面的代码定义了一个SGD优化器,学习率为0.001,动量为0.9。每个epoch内,遍历整个训练集,每次取出一个batch进行训练,计算损失并反向传播更新参数。
4. 测试模型
训练完成后,可以使用测试集来评估模型的性能。
```
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
```
上面的代码使用测试集来测试模型的准确率。
5. 保存和加载模型
训练完成后,可以将模型保存下来,以便在以后再次使用。
```
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
```
上面的代码将模型参数保存到文件cifar_net.pth中。
加载模型时,可以使用以下代码:
```
net = Net()
net.load_state_dict(torch.load(PATH))
```
上面的代码创建了一个新的模型,并将保存的参数加载到该模型中。
6. 总结
以上就是PyTorch图像分类实战的主要内容。通过以上实战,可以了解到PyTorch中如何准备数据集、定义模型、训练模型、测试模型以及保存和加载模型等基本操作。同时,还可以通过实战体验到PyTorch的简单易用、灵活性高等特点。
阅读全文