使用 PyTorch 训练 ResNet34 进行 4 类别分类，并打印混淆矩阵的代码
时间: 2023-05-14 12:07:23 浏览: 106
以下是使用 PyTorch 训练 ResNet34 进行 4 类别分类并打印混淆矩阵的代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
import numpy as np
# Define transforms for training and testing data.
# ResNet34 in this script is built with pretrained=True (ImageNet weights);
# torchvision documents that its pretrained models expect inputs normalized
# with the ImageNet per-channel mean/std, so use those instead of 0.5/0.5.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # pad 4px then take a random 32x32 crop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# Evaluation uses no augmentation, only tensor conversion + the same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
# Load CIFAR-10. It has 10 classes (labels 0-9), but the ResNet34 head in this
# script defaults to 4 outputs; any target >= 4 makes CrossEntropyLoss fail
# with an out-of-range target error. Keep only labels 0..NUM_CLASSES-1 so the
# data actually matches a 4-way classification task.
NUM_CLASSES = 4


def _first_n_classes(dataset, n):
    """Return a ``Subset`` of *dataset* holding only samples whose label is < *n*.

    Relies on ``dataset.targets`` being a per-sample sequence of int labels
    (torchvision's CIFAR10 exposes this). Labels 0..n-1 are already contiguous,
    so no remapping is needed before ``CrossEntropyLoss``.
    """
    keep = [idx for idx, target in enumerate(dataset.targets) if target < n]
    return torch.utils.data.Subset(dataset, keep)


trainset = _first_n_classes(
    torchvision.datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform_train),
    NUM_CLASSES,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
testset = _first_n_classes(
    torchvision.datasets.CIFAR10(root='./data', train=False,
                                 download=True, transform=transform_test),
    NUM_CLASSES,
)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)
# Define ResNet34 model
class ResNet34(nn.Module):
    """torchvision ResNet-34 with its final fully-connected layer resized.

    Args:
        num_classes: Number of logits produced by the replacement ``fc`` layer.
        pretrained: Load the torchvision ImageNet weights for the backbone
            (default True, matching the previous behavior).
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        weights_enum = getattr(torchvision.models, "ResNet34_Weights", None)
        if weights_enum is not None:
            # Newer torchvision: the ``weights`` API replaces the deprecated
            # ``pretrained`` flag (which only emits a warning there).
            self.resnet = torchvision.models.resnet34(
                weights=weights_enum.DEFAULT if pretrained else None)
        else:
            # Older torchvision only understands ``pretrained``.
            self.resnet = torchvision.models.resnet34(pretrained=pretrained)
        # Size the new head from the backbone's feature width instead of
        # hard-coding 512.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits for the image batch *x*."""
        return self.resnet(x)
# Define device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Initialize model and optimizer
net = ResNet34().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Train the model
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 100 == 99:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
# Test the model
net.eval()
y_true = []
y_pred = []
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
y_true += labels.cpu().numpy().tolist()
y_pred += predicted.cpu().numpy().tolist()
# Print confusion matrix
conf_mat = confusion_matrix(y_true, y_pred)
print(conf_mat)
```
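请注意，这段代码只是一个示例，需要根据你的数据集和分类任务进行修改。例如，CIFAR-10 有 10 个类别（标签 0–9），而模型只输出 4 类；直接训练时，标签超出范围会导致 CrossEntropyLoss 报错。因此需要先从 CIFAR-10 中筛选出 4 个类别，或改用你自己的 4 类数据集（如 `torchvision.datasets.ImageFolder`）。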
阅读全文