Complete PyTorch code for k-fold cross-validation
Posted: 2023-11-20 20:14:19 · Views: 57
Sure! Here's an example of how you can implement k-fold cross-validation using PyTorch:
```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

# Define your dataset
dataset = YourDataset()

# Set hyperparameters
num_epochs = 10
batch_size = 64
k_folds = 5

# Base size of each fold (the last fold absorbs any remainder)
fold_size = len(dataset) // k_folds

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Perform k-fold cross-validation
for fold in range(k_folds):
    # Validation slice for this fold; everything else is training data
    start_index = fold * fold_size
    end_index = len(dataset) if fold == k_folds - 1 else (fold + 1) * fold_size
    val_indices = list(range(start_index, end_index))
    train_indices = [i for i in range(len(dataset)) if i < start_index or i >= end_index]

    # Prepare the data loaders for training and validation
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=SubsetRandomSampler(train_indices))
    val_loader = DataLoader(dataset, batch_size=batch_size,
                            sampler=SubsetRandomSampler(val_indices))

    # Re-initialize the model and optimizer so each fold starts from scratch
    model = YourModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Train the model for this fold
    for epoch in range(num_epochs):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # Evaluate on the validation set after each epoch
        model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                total_loss += criterion(outputs, targets).item()
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        # Print the validation accuracy for this fold and epoch
        print(f"Fold {fold + 1}, Epoch {epoch + 1}: "
              f"Validation Accuracy = {100 * correct / total:.2f}%")
```
In this code, replace `YourDataset` and `YourModel` with your own dataset and model classes. The code performs k-fold cross-validation by carving a different contiguous validation slice out of the dataset on each fold and routing the remaining indices to the training loader via `SubsetRandomSampler`. Note that the model and optimizer should be re-initialized at the start of each fold; otherwise weights trained on earlier folds leak into later ones and the validation scores are no longer independent estimates.
Remember to adjust the hyperparameters (`num_epochs`, `batch_size`, `k_folds`, etc.) to your specific task.
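As an alternative to computing fold boundaries by hand, the index splits can also be generated with scikit-learn's `KFold` (this assumes scikit-learn is installed; it is not required by the code above). The resulting index arrays plug directly into `SubsetRandomSampler`:

```python
from sklearn.model_selection import KFold
import numpy as np

n_samples = 20  # stand-in for len(dataset)
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Each split is a (train_indices, val_indices) pair over range(n_samples)
splits = list(kf.split(np.arange(n_samples)))
for fold, (train_idx, val_idx) in enumerate(splits):
    print(f"Fold {fold}: {len(train_idx)} train / {len(val_idx)} val samples")
```

Unlike the contiguous slicing above, `shuffle=True` randomizes which samples land in each fold, which is usually preferable when the dataset is ordered (e.g. sorted by class).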
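For a fully runnable end-to-end check, the same loop can be exercised on a tiny synthetic dataset. The data, the `nn.Linear` model, and the hyperparameters below are illustrative stand-ins, not part of the original answer:

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler

torch.manual_seed(0)

# Synthetic dataset: 100 samples, 4 features, 3 classes (illustrative only)
X = torch.randn(100, 4)
y = torch.randint(0, 3, (100,))
dataset = TensorDataset(X, y)

k_folds = 5
n = len(dataset)
fold_size = n // k_folds

accuracies = []
for fold in range(k_folds):
    start = fold * fold_size
    end = n if fold == k_folds - 1 else start + fold_size
    val_idx = list(range(start, end))
    train_idx = [i for i in range(n) if i < start or i >= end]

    train_loader = DataLoader(dataset, batch_size=16,
                              sampler=SubsetRandomSampler(train_idx))
    val_loader = DataLoader(dataset, batch_size=16,
                            sampler=SubsetRandomSampler(val_idx))

    # Fresh model and optimizer per fold so folds stay independent
    model = nn.Linear(4, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(3):
        model.train()
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            pred = model(xb).argmax(dim=1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    accuracies.append(correct / total)

print(f"Mean validation accuracy over {k_folds} folds: "
      f"{sum(accuracies) / len(accuracies):.3f}")
```

The per-fold accuracies on random labels will hover near chance (~33%); the point is that the loop runs, each fold trains a fresh model, and every sample appears in exactly one validation set.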