resnet18 cifar100 pytorch
时间: 2023-11-16 11:00:26 浏览: 167
ResNet18 是一种深度残差卷积神经网络模型，适用于图像分类任务。CIFAR-100 是一个包含 100 个类别的图像分类数据集，共有 60000 张 32×32 的彩色图像（每个类别 500 张训练图像、100 张测试图像）。在 PyTorch 中，可以使用 ResNet18 模型对 CIFAR-100 数据集进行训练和测试。
使用 ResNet18 在 CIFAR-100 数据集上进行训练的参数和代码如下：
```python
# 导入必要的库
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
# Load CIFAR-100 and wrap it in data loaders.
#
# Per-channel mean/std of the CIFAR-100 training images (RGB, on the [0, 1]
# scale produced by ToTensor). The previous values were ImageNet's statistics,
# which do not standardize CIFAR-100 inputs to zero mean / unit variance.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)
normalize = transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD)

# Training augmentation: random 32x32 crop from the image zero-padded by
# 4 pixels on each side, plus a random horizontal flip.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Test images are only converted and normalized, so evaluation is deterministic.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)

# num_workers stays at the default 0: this script has no visible
# `if __name__ == '__main__':` guard, so worker processes started with the
# "spawn" method would re-execute the whole script.
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
# Define the model.
class _BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv/BN layers plus a residual shortcut.

    The shortcut is the identity when the output shape matches the input, and
    a strided 1x1 conv + BN projection when ``stride != 1`` or the channel
    count changes.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if stride != 1 or in_planes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)


class ResNet18(nn.Module):
    """ResNet-18 image classifier adapted to small (e.g. 32x32 CIFAR) inputs.

    Unlike the ImageNet ResNet-18, the stem is a single 3x3, stride-1
    convolution with no max-pool. A 32x32 image therefore enters ``layer1`` at
    32x32, instead of being reduced to 8x8 by a 7x7/stride-2 conv followed by
    a stride-2 max-pool. The network is built from ``torch.nn`` directly, so
    constructing it needs no download. Layer names (``conv1``, ``bn1``,
    ``layer1``..``layer4``, ``fc``) follow torchvision's ResNet and live under
    the ``resnet18`` attribute.

    Args:
        num_classes: number of logits produced by the final linear layer.

    Forward takes an ``(N, 3, H, W)`` float tensor and returns
    ``(N, num_classes)`` logits. Adaptive average pooling makes the output
    size independent of H and W.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        backbone = nn.Sequential()
        backbone.add_module('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False))
        backbone.add_module('bn1', nn.BatchNorm2d(64))
        backbone.add_module('relu', nn.ReLU(inplace=True))
        in_planes = 64
        # Four stages of two basic blocks each; every stage after the first
        # halves the spatial size and doubles the width.
        for index, (planes, stride) in enumerate([(64, 1), (128, 2), (256, 2), (512, 2)], start=1):
            backbone.add_module(
                'layer{}'.format(index),
                nn.Sequential(_BasicBlock(in_planes, planes, stride), _BasicBlock(planes, planes, 1)),
            )
            in_planes = planes
        backbone.add_module('avgpool', nn.AdaptiveAvgPool2d(1))
        backbone.add_module('flatten', nn.Flatten())
        # Size the classifier from the last stage's width instead of a hard-coded 512.
        backbone.add_module('fc', nn.Linear(in_planes, num_classes))
        self.resnet18 = backbone

        # Same initialization scheme torchvision uses for its ResNets.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        return self.resnet18(x)
# ---- Training configuration ----
# Prefer a CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = ResNet18()
model = model.to(device)

# Cross-entropy on the model's logits (batch-mean reduction by default).
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay.
# NOTE(review): lr=0.002 looks low for batch_size=1024 (CIFAR recipes
# commonly use ~0.1 at batch 128); confirm this is intentional.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.002,
    momentum=0.9,
    weight_decay=5e-4,
)

# Multiply the learning rate by 0.1 at milestones 80 and 120. Milestones count
# scheduler.step() calls, which the training loop makes once per epoch.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)

num_epochs = 200
# Best test accuracy seen so far; a checkpoint is written whenever it improves.
best_acc = 0.0
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given, the model is switched to training mode and
    updated after every batch. Otherwise it is switched to eval mode and
    autograd is disabled, so no graph is kept for the evaluation forward passes.

    Both returned values are Python floats, averaged over
    ``len(loader.dataset)`` samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    correct = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Re-weight the (presumably batch-mean) loss by batch size, so the
            # epoch average is per-sample even when the last batch is smaller.
            total_loss += loss.item() * images.size(0)
            # Count as a Python int: with older PyTorch, dividing an integer
            # tensor by the dataset size is not true division.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    num_samples = len(loader.dataset)
    return total_loss / num_samples, correct / num_samples


# torch.save does not create missing directories, so make sure ./res exists
# before the first checkpoint is written.
checkpoint_path = './res/best_model.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
    test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)
    # NOTE(review): the "best" checkpoint is selected on the test set, so the
    # test accuracy of that checkpoint is an optimistic estimate; a separate
    # validation split would avoid this.
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), checkpoint_path)
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'.format(
        epoch + 1, num_epochs, train_loss, train_acc, test_loss, test_acc))
    # One scheduler step per epoch, so MultiStepLR milestones are epoch counts.
    scheduler.step()
# Evaluate the best checkpoint saved during training on the test set.
# map_location moves the stored tensors onto the current device, so the load
# also works if the checkpoint was written on a different device.
model.load_state_dict(torch.load('./res/best_model.pth', map_location=device))
model.eval()
test_correct = 0
# Inference only: skip building autograd graphs to save memory.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Accumulate a Python int: with older PyTorch, dividing an integer
        # tensor below would not be true division.
        test_correct += (outputs.argmax(dim=1) == labels).sum().item()
test_acc = test_correct / len(test_loader.dataset)
print('Test Acc: {:.4f}'.format(test_acc))
```
阅读全文