写一个使用 PyTorch 和 VGG16 分析 JAFFE 数据集的代码，将数据划分为训练集、验证集和测试集，并给出准确率图、损失图和混淆矩阵。
时间: 2023-06-10 15:03:11 浏览: 113
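好的，以下是您所需的代码。请确保已将 `data_path` 设置为 JAFFE 数据集的正确路径，并已正确安装 PyTorch、torchvision、NumPy、Matplotlib 和 scikit-learn。由于代码使用 torchvision 的 `ImageFolder` 加载数据，数据集需要按表情类别分别放在不同的子文件夹中（例如 `AN/`、`HA/` 等）；JAFFE 原始图片文件名中带有表情缩写，可据此整理。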
```python
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# Path to the JAFFE dataset. torchvision's ImageFolder expects one sub-folder
# per class, e.g. data_path/AN/*.tiff, data_path/HA/*.tiff, ...
data_path = "path/to/dataset"

# JAFFE images are grayscale, but the ImageNet-pretrained VGG16 expects
# 3-channel input. Replicate the single channel three times and normalise with
# the ImageNet mean/std that the pretrained weights were trained with.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(num_output_channels=3),
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])
dataset = torchvision.datasets.ImageFolder(root=data_path, transform=transform)
# ImageFolder derives the class list from the sub-folder names.
n_classes = len(dataset.classes)
# Split the dataset 80% / 10% / 10% into train / validation / test.
# The test split absorbs the rounding remainder.
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
# A fixed seed makes the split, and therefore the reported metrics, reproducible.
# NOTE(review): the split is per image, so the same subject can appear in both
# train and test. Use a subject-wise split for subject-independent evaluation.
split_seed = 42
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed))

batch_size = 32
# Only the training data needs shuffling; evaluation order does not affect the metrics.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# ---------------------------------------------------------------------------
# Model, loss and optimizer
# ---------------------------------------------------------------------------
# Use the GPU when one is available; fine-tuning VGG16 on the CPU is very slow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load VGG16 with ImageNet-pretrained weights. torchvision >= 0.13 deprecates
# ``pretrained=True`` in favour of the ``weights=`` enum. Use the enum when it
# exists and fall back to the old flag on older torchvision releases.
if hasattr(torchvision.models, "VGG16_Weights"):
    model = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
else:
    model = torchvision.models.vgg16(pretrained=True)
# Replace the last classifier layer with one sized for our classes. Reuse the
# existing layer's input width rather than hard-coding 4096.
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, n_classes)
# Move the model before building the optimizer, as the torch.optim docs advise.
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


def _run_epoch(loader, train):
    """Make one pass over *loader* and return ``(mean_loss, accuracy)``.

    With ``train=True`` the model is put in training mode and its weights are
    updated. With ``train=False`` it is put in evaluation mode (dropout off)
    and no gradients are tracked. Returns ``(nan, nan)`` for an empty loader.
    """
    model.train(train)
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(train):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight by batch size so a short last batch doesn't skew the mean.
            total_loss += loss.item() * labels.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)
    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_seen, n_correct / n_seen


# ---------------------------------------------------------------------------
# Training with per-epoch validation
# ---------------------------------------------------------------------------
n_epochs = 20
train_loss = []
val_loss = []
train_acc = []
val_acc = []
for epoch in range(n_epochs):
    epoch_train_loss, epoch_train_acc = _run_epoch(train_loader, train=True)
    epoch_val_loss, epoch_val_acc = _run_epoch(val_loader, train=False)
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)
    val_loss.append(epoch_val_loss)
    val_acc.append(epoch_val_acc)
    print('[Epoch %d] train_loss: %.3f | train_acc: %.3f | val_loss: %.3f | val_acc: %.3f' %
          (epoch + 1, train_loss[-1], train_acc[-1], val_loss[-1], val_acc[-1]))

# ---------------------------------------------------------------------------
# Test-set evaluation
# ---------------------------------------------------------------------------
model.eval()
test_true = []
test_pred = []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs.to(device))
        test_true.extend(labels.tolist())
        test_pred.extend(outputs.argmax(dim=1).cpu().tolist())

if test_true:
    test_acc = float(np.mean(np.array(test_true) == np.array(test_pred)))
else:
    test_acc = float("nan")
print('Test accuracy: %.3f' % test_acc)
# Pass every class index so the matrix is always n_classes x n_classes, even
# if a class is missing from the small test split. This keeps tick labels aligned.
conf_matrix = confusion_matrix(test_true, test_pred, labels=np.arange(n_classes))
# Row-normalise so each row shows per-class recall. An all-zero row stays zero
# instead of becoming NaN from 0 / 0.
conf_matrix = conf_matrix / np.maximum(conf_matrix.sum(axis=1, keepdims=True), 1)
# Plot training/validation loss and accuracy per epoch, each in its own figure.
# The x-axis is numbered from 1 so it matches the "[Epoch N]" training log;
# plt.plot(y) alone would place epoch 1 at x = 0.
epoch_axis = np.arange(1, len(train_loss) + 1)
curve_specs = (
    ('Loss Over Time', 'Loss',
     train_loss, 'Training Loss', val_loss, 'Validation Loss'),
    ('Accuracy Over Time', 'Accuracy',
     train_acc, 'Training Accuracy', val_acc, 'Validation Accuracy'),
)
for title, ylabel, train_curve, train_label, val_curve, val_label in curve_specs:
    plt.figure()
    plt.plot(epoch_axis, train_curve, label=train_label)
    plt.plot(epoch_axis, val_curve, label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()
# Plot the row-normalised confusion matrix: rows are true classes, columns are
# predicted classes. Each cell's value is written on top so small differences
# stay readable, and the colour scale is pinned to [0, 1] to match the
# normalisation.
plt.figure(figsize=(6, 6))
plt.imshow(conf_matrix, cmap='Blues', vmin=0, vmax=1)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.xticks(np.arange(n_classes), dataset.classes, rotation=90)
plt.yticks(np.arange(n_classes), dataset.classes)
for row, col in np.ndindex(conf_matrix.shape):
    cell = conf_matrix[row, col]
    # Switch to white text on dark cells so the number stays legible.
    plt.text(col, row, '%.2f' % cell, ha='center', va='center',
             color='white' if cell > 0.5 else 'black')
plt.colorbar()
# Leave room for the rotated x tick labels so they are not clipped.
plt.tight_layout()
plt.show()
```
希望代码能够帮到您!
阅读全文