使用PyTorch对CIFAR-10数据集处理并绘图,导出模型示例
时间: 2024-03-02 22:50:24 浏览: 121
好的,以下是一个使用PyTorch对CIFAR-10数据集进行处理、绘图并导出模型的示例:
```python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# 定义一些数据处理的transform
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
# 定义类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 显示一些训练图片
def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
# 定义一个CNN模型
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.bn1 = torch.nn.BatchNorm2d(64)
self.relu1 = torch.nn.ReLU(inplace=True)
self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.bn2 = torch.nn.BatchNorm2d(128)
self.relu2 = torch.nn.ReLU(inplace=True)
self.conv3 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.bn3 = torch.nn.BatchNorm2d(256)
self.relu3 = torch.nn.ReLU(inplace=True)
self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = torch.nn.Linear(256 * 8 * 8, 512)
self.relu4 = torch.nn.ReLU(inplace=True)
self.fc2 = torch.nn.Linear(512, 10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
x = self.pool(x)
x = x.view(-1, 256 * 8 * 8)
x = self.fc1(x)
x = self.relu4(x)
x = self.fc2(x)
return x
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# 训练模型
net = Net()
for epoch in range(100):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print('[%d] loss: %.3f' % (epoch + 1, running_loss / len(trainloader)))
# 保存模型
torch.save(net.state_dict(), 'cifar_net.pth')
```
这个示例中,我们首先定义了一些数据处理的transform,然后使用`torchvision.datasets.CIFAR10`加载CIFAR-10数据集,并使用`torch.utils.data.DataLoader`生成数据迭代器。接着,我们定义了一个简单的CNN模型,并使用交叉熵损失函数和随机梯度下降优化器进行训练。最后,我们保存了训练好的模型。在代码中,我们还显示了一些训练图片,并使用了一个名为`imshow`的函数来实现图片的展示。这个示例中的一些处理步骤包括数据增强、归一化、批量处理等,这些都是为了提高模型的训练效果。
阅读全文