pytorch实现10折交叉验证
时间: 2023-11-05 08:59:44 浏览: 167
十折交叉验证pytorch搭建整个模型的代码.rar
5星 · 资源好评率100%
10折交叉验证是一种常用的评估模型性能的方法,下面是使用 PyTorch 实现的示例代码:
```python
import torch
import numpy as np
from sklearn.model_selection import KFold
# 定义数据集和模型
X = torch.randn(100, 10)
y = torch.randn(100, 1)
model = torch.nn.Linear(10, 1)
# 定义损失函数和优化器
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 定义 10 折交叉验证
kfold = KFold(n_splits=10, shuffle=True)
# 进行交叉验证
for fold, (train_idx, val_idx) in enumerate(kfold.split(X)):
print(f"Fold [{fold+1}/10]")
# 划分数据集
X_train, y_train = X[train_idx], y[train_idx]
X_val, y_val = X[val_idx], y[val_idx]
# 训练模型
for epoch in range(100):
y_pred = model(X_train)
loss = criterion(y_pred, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 在验证集上评估模型性能
with torch.no_grad():
y_pred = model(X_val)
val_loss = criterion(y_pred, y_val)
print(f"Validation Loss: {val_loss:.4f}")
```
以上代码中,我们使用了 PyTorch 自带的 `torch.randn` 函数生成了一个随机的数据集。我们在模型训练时,先使用 `KFold` 函数将数据集划分为 10 个子集,然后在每个子集上进行训练和验证。在每个子集上训练结束后,我们使用 `torch.no_grad()` 包装上下文,对模型在验证集上的性能进行评估。
阅读全文