使用 ResNet34 对 4 个类别进行训练和分类并打印混淆矩阵的 PyTorch 代码
时间: 2023-05-16 13:04:26 浏览: 222
以下是使用 PyTorch 训练 ResNet34 进行 4 类别分类并打印混淆矩阵的代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
# Define transforms for training and testing data.
# The ResNet34 backbone is loaded with ImageNet-pretrained weights, and
# torchvision's pretrained classification models expect inputs normalized
# with the ImageNet per-channel mean/std. Both pipelines therefore use those
# statistics instead of a generic (0.5, 0.5, 0.5) mean/std.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform_train = transforms.Compose([
    # Augmentation for the training split only: pad by 4 px and randomly
    # crop back to 32x32, plus random horizontal flips.
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
transform_test = transforms.Compose([
    # Deterministic evaluation pipeline: no augmentation.
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# Load CIFAR-10 and keep only 4 of its 10 classes.
# CIFAR-10 labels run 0..9, but the ResNet34 head used by this script has
# 4 outputs; nn.CrossEntropyLoss raises a "target out of bounds" error for
# any label >= 4. Keeping only samples whose label is 0..3 leaves labels that
# are already valid indices for a 4-way head, so no remapping is needed.
# (Class names are available via `trainset.dataset.classes`.)
NUM_CLASSES = 4


def _keep_first_classes(dataset, num_classes):
    """Return a Subset of *dataset* with only the samples whose label < num_classes.

    Filters on ``dataset.targets`` (torchvision's CIFAR10 stores the integer
    labels there), so no image has to be loaded or transformed to filter.
    """
    indices = [idx for idx, label in enumerate(dataset.targets) if label < num_classes]
    return torch.utils.data.Subset(dataset, indices)


trainset = _keep_first_classes(
    torchvision.datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform_train),
    NUM_CLASSES,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
testset = _keep_first_classes(
    torchvision.datasets.CIFAR10(root='./data', train=False,
                                 download=True, transform=transform_test),
    NUM_CLASSES,
)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)
# Define ResNet34 model
class ResNet34(nn.Module):
    """torchvision ResNet-34 with its final fully connected layer resized.

    Args:
        num_classes: number of logits produced by the new classification head.
        pretrained: load torchvision's ImageNet-pretrained backbone weights
            (default True, as before). Pass False to build a randomly
            initialised network, e.g. when no download is possible.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`; it is kept here for compatibility with older releases.
        self.resnet = torchvision.models.resnet34(pretrained=pretrained)
        # Take the head's input width from the existing layer rather than
        # hard-coding 512, so the replacement always matches the backbone.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return raw (un-softmaxed) class logits for the image batch *x*."""
        return self.resnet(x)
# Build the network, the classification loss and the optimiser.
# SGD with momentum and L2 weight decay is applied to every parameter of the
# network (pretrained backbone and new head alike).
# NOTE(review): lr=0.1 is a typical from-scratch CIFAR setting; when
# fine-tuning pretrained weights a smaller learning rate is usually used —
# consider lowering it.
net = ResNet34()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
# Train the model
def train_one_epoch(model, loader, criterion, optimizer, epoch, log_every=100):
    """Run one optimisation pass over *loader* and return the mean batch loss.

    Every *log_every* batches, prints the average loss of that window as
    ``[epoch, batch] loss: x.xxx``, with epoch and batch printed 1-based.

    Args:
        model: network to train; it is switched to training mode here.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimiser over ``model``'s parameters.
        epoch: zero-based epoch index, used only for logging.
        log_every: logging interval, in batches.

    Returns:
        Mean loss over all batches of the epoch (0.0 if *loader* is empty).
    """
    model.train()
    # Send each batch to wherever the model's parameters live, so the loop
    # keeps working if the model has been moved to a GPU.
    device = next(model.parameters()).device
    running_loss = 0.0
    total_loss = 0.0
    num_batches = 0
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss
        num_batches += 1
        if i % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.0
    return total_loss / num_batches if num_batches else 0.0


# Guard the driver: DataLoader workers started with the "spawn" method
# (the default on Windows and macOS) re-import this script, and must not
# re-run the training loop when they do.
if __name__ == "__main__":
    for epoch in range(50):
        train_one_epoch(net, trainloader, criterion, optimizer, epoch)
# Test the model and print confusion matrix
def evaluate(model, loader):
    """Collect the true and predicted labels for every sample in *loader*.

    Switches *model* to evaluation mode and runs it without gradient tracking.

    Args:
        model: classifier that returns one row of class scores per sample.
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        ``(y_true, y_pred)``: two flat lists of Python ints, in loader order.
    """
    model.eval()
    # Send each batch to wherever the model's parameters live.
    device = next(model.parameters()).device
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            # Predicted class = index of the highest score in each row.
            y_pred.extend(outputs.argmax(dim=1).tolist())
            y_true.extend(labels.tolist())
    return y_true, y_pred


# Guarded for the same reason as training: spawn-started DataLoader workers
# re-import this script and must not re-run evaluation.
if __name__ == "__main__":
    y_true, y_pred = evaluate(net, testloader)
    # Pass the full label set so the matrix is always C x C, even when some
    # class never occurs in y_true or y_pred.
    class_ids = list(range(net.resnet.fc.out_features))
    print(confusion_matrix(y_true, y_pred, labels=class_ids))
```
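请注意：这段代码只是一个示例。CIFAR-10 共有 10 个类别，而这里的模型只输出 4 个类别；若直接用完整的 CIFAR-10 训练，标签 4–9 会超出分类头的范围，导致 `nn.CrossEntropyLoss` 报错。因此需要只保留其中 4 个类别，或改用你自己的 4 类数据集（例如用 `torchvision.datasets.ImageFolder` 按文件夹加载）。如果要训练其他数据集或其他数量的类别，也需要相应地修改代码（包括 `num_classes`）。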
阅读全文