Binary image classification with cross-validation in PyTorch
In PyTorch, validating a binary image-classification model typically involves the following steps: preprocess the data, load and split the dataset, define a model, and run a training loop with evaluation on held-out data. The example below uses a simple 80/20 hold-out split; a k-fold cross-validation sketch follows after the code.
```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms
from torch.nn import functional as F
from torch import nn, optim
# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # adjust normalization stats to your dataset
])
# Load the dataset (ImageFolder expects one subfolder per class)
dataset = datasets.ImageFolder(root='path_to_your_dataset', transform=transform)
# Split into training and validation sets (80/20 hold-out)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# Create DataLoaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Define the model (a simple convolutional network as an example)
class TwoClassClassifier(nn.Module):
    def __init__(self):
        super(TwoClassClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # After one conv (same padding) and one 2x2 pool, a 64x64 input becomes 16 x 32 x 32
        self.fc1 = nn.Linear(16 * 32 * 32, 64)
        self.fc2 = nn.Linear(64, 2)  # two output logits for binary classification

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(x.size(0), -1)  # flatten to (batch, 16*32*32)
        x = F.relu(self.fc1(x))
        # Return raw logits; nn.CrossEntropyLoss applies log-softmax internally
        return self.fc2(x)
model = TwoClassClassifier()
criterion = nn.CrossEntropyLoss()  # cross-entropy loss over the two class logits (expects raw logits)
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer
# Training loop with validation after each epoch
num_epochs = 10
for epoch in range(num_epochs):
    # Train for one epoch
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluate on the validation set
    model.eval()
    total_val_loss = 0
    correct_count = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total_val_loss += criterion(outputs, labels).item()
            correct_count += (predicted == labels).sum().item()

    avg_val_loss = total_val_loss / len(val_loader)
    accuracy = correct_count / len(val_dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy*100:.2f}%")
```
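The code above validates on a single hold-out split. For cross-validation in the strict k-fold sense, the same dataset and model class can be reused across folds. Below is a minimal sketch, assuming scikit-learn is available for `KFold`; the names `k_folds`, `fold_accuracies`, `train_idx`, and `val_idx` are illustrative and not part of any library API.
```python
# Minimal k-fold cross-validation sketch (reuses dataset, batch_size, num_epochs, criterion from above)
from sklearn.model_selection import KFold
from torch.utils.data import Subset

k_folds = 5
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

fold_accuracies = []
for fold, (train_idx, val_idx) in enumerate(kfold.split(list(range(len(dataset))))):
    # Per-fold loaders built from index subsets of the full ImageFolder dataset
    fold_train_loader = DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=True)
    fold_val_loader = DataLoader(Subset(dataset, val_idx), batch_size=batch_size, shuffle=False)

    # Re-initialise the model and optimizer for each fold so the folds stay independent
    model = TwoClassClassifier()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        model.train()
        for images, labels in fold_train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

    # Evaluate accuracy on this fold's validation indices
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in fold_val_loader:
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
    acc = correct / len(val_idx)
    fold_accuracies.append(acc)
    print(f"Fold {fold + 1}/{k_folds}: accuracy {acc * 100:.2f}%")

print(f"Mean accuracy over {k_folds} folds: {100 * sum(fold_accuracies) / k_folds:.2f}%")
```
Because plain `KFold` ignores class labels, folds can end up imbalanced on small datasets; `sklearn.model_selection.StratifiedKFold`, fed with `dataset.targets`, is a drop-in alternative when that matters.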