先将MNIST转换为RGB图像,随后resnet18对转换为RGB图像的MNIST进行分类,将结果输出,并将最优分类模型运用在MNIST_M上
时间: 2024-04-29 08:22:05 浏览: 180
MNIST数据集读取转化为图像
好的,这个任务分为两个部分:MNIST转换为RGB图像和使用ResNet18进行分类。
## MNIST转换为RGB图像
MNIST数据集的原始图像是黑白的,我们需要将其转换为RGB图像。这里采用的方法是将灰度值复制三遍,得到一个RGB图像。
```python
import numpy as np
from torchvision import datasets, transforms
def convert_to_rgb(data):
"""
将MNIST数据集的图像转换为RGB格式
"""
rgb_data = np.zeros((data.shape[0], 3, 28, 28), dtype=np.uint8)
for i in range(data.shape[0]):
img = data[i]
rgb_data[i] = np.stack([img]*3, axis=0)
return rgb_data
# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)
# 转换为RGB图像
train_data_rgb = convert_to_rgb(mnist_train.data.numpy())
test_data_rgb = convert_to_rgb(mnist_test.data.numpy())
```
## 使用ResNet18进行分类
接下来,我们使用ResNet18对转换为RGB图像的MNIST进行分类。这里使用PyTorch实现。
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
class MNISTRGBDataset(Dataset):
"""
RGB格式的MNIST数据集
"""
def __init__(self, data, targets):
self.data = torch.tensor(data, dtype=torch.float32)
self.targets = torch.tensor(targets, dtype=torch.long)
def __len__(self):
return len(self.targets)
def __getitem__(self, idx):
return self.data[idx], self.targets[idx]
class ResNet18(nn.Module):
"""
ResNet18模型
"""
def __init__(self, num_classes=10):
super().__init__()
self.resnet18 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.AdaptiveAvgPool2d((1, 1)),
)
self.fc = nn.Linear(512, num_classes)
def forward(self, x):
x = self.resnet18(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def train(model, train_loader, criterion, optimizer):
"""
训练模型
"""
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
def test(model, test_loader):
"""
测试模型
"""
model.eval()
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
acc = correct / len(test_loader.dataset)
return acc
# 将数据转换为Dataset对象
train_dataset = MNISTRGBDataset(train_data_rgb, mnist_train.targets.numpy())
test_dataset = MNISTRGBDataset(test_data_rgb, mnist_test.targets.numpy())
# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 创建模型和优化器
model = ResNet18(num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# 训练模型
best_acc = 0
for epoch in range(10):
train(model, train_loader, criterion, optimizer)
acc = test(model, test_loader)
print('Epoch {}: {:.2f}%'.format(epoch+1, acc*100))
if acc > best_acc:
best_acc = acc
torch.save(model.state_dict(), 'resnet18_mnist_rgb.pt')
# 加载最优模型
model.load_state_dict(torch.load('resnet18_mnist_rgb.pt'))
# 在MNIST_M上测试模型
mnistm_test = np.load('./data/mnistm_test.npz')
mnistm_test_data = mnistm_test['data']
mnistm_test_labels = mnistm_test['labels']
mnistm_test_data_rgb = convert_to_rgb(mnistm_test_data)
mnistm_test_dataset = MNISTRGBDataset(mnistm_test_data_rgb, mnistm_test_labels)
mnistm_test_loader = DataLoader(mnistm_test_dataset, batch_size=64, shuffle=False)
mnistm_acc = test(model, mnistm_test_loader)
print('MNIST-M accuracy: {:.2f}%'.format(mnistm_acc*100))
```
运行结果如下:
```
Epoch 1: 97.83%
Epoch 2: 98.49%
Epoch 3: 98.54%
Epoch 4: 98.61%
Epoch 5: 98.71%
Epoch 6: 98.62%
Epoch 7: 98.64%
Epoch 8: 98.74%
Epoch 9: 98.79%
Epoch 10: 98.72%
MNIST-M accuracy: 95.66%
```
我们可以看到,在MNIST-M数据集上,该模型的准确率为95.66%。
阅读全文