PyTorch 网格搜索代码
时间: 2023-09-12 16:10:00 浏览: 44
以下是使用 PyTorch 进行网格搜索(grid search)的示例代码:
```python
import copy
import itertools

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm
# Model definition: a two-layer fully connected classifier.
class Net(torch.nn.Module):
    """Two-layer MLP that outputs per-class log-probabilities.

    The log-softmax output is meant to be paired with ``F.nll_loss``.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of the single hidden layer.
        output_dim: Number of output classes.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` to log-probabilities over ``output_dim`` classes.

        ``x`` is expected to be already flattened to ``(batch, input_dim)``.
        ``log_softmax`` is taken over dim 1, so a 2-D input is assumed.
        """
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# Fixed problem dimensions: flattened 28x28 MNIST images, 10 digit classes.
input_dim = 784
output_dim = 10
num_epochs = 10
val_size = 5000

# Hyperparameter grid: every combination of these values is tried.
lr = [0.01, 0.001]
batch_size = [64, 128]
hidden_dim = [128, 256]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset: split the MNIST training set once, so every configuration is
# scored on the same held-out validation images.
dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
train_data, val_data = random_split(dataset, [len(dataset) - val_size, val_size])


def _train_one_epoch(model, loader, optimizer):
    """Run one pass of NLL-loss training over *loader*, updating *model* in place."""
    model.train()
    for images, labels in tqdm(loader, desc="train", leave=False):
        images = images.view(-1, input_dim).to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(images), labels)
        loss.backward()
        optimizer.step()


def _evaluate(model, loader):
    """Return the classification accuracy of *model* on *loader*, in percent."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="val", leave=False):
            images = images.view(-1, input_dim).to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


# Grid search: train each configuration and keep the best validation snapshot.
best_acc = 0
best_model = None
best_hyperparams = None
best_epoch = None
for lr_, batch_size_, hidden_dim_ in itertools.product(lr, batch_size, hidden_dim):
    # Batch size is one of the searched hyperparameters, so the loaders must
    # be rebuilt for every configuration (DataLoader needs a single int).
    train_loader = DataLoader(train_data, batch_size=batch_size_, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size_, shuffle=False)
    model = Net(input_dim, hidden_dim_, output_dim).to(device)
    optimizer = Adam(model.parameters(), lr=lr_)
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, optimizer)
        acc = _evaluate(model, val_loader)
        if best_model is None or acc > best_acc:
            best_acc = acc
            # Snapshot the model: it keeps training in later epochs, so keeping
            # a bare reference would overwrite the best weights.
            best_model = copy.deepcopy(model)
            best_hyperparams = {'lr': lr_, 'batch_size': batch_size_, 'hidden_dim': hidden_dim_}
            best_epoch = epoch + 1

print(f"best val acc: {best_acc:.2f}% with {best_hyperparams} (epoch {best_epoch})")
```
在该代码中,我们使用 `itertools.product` 函数生成超参数的所有组合,然后在每个超参数组合下训练并验证模型,最后选择验证精度最高的模型作为最佳模型。可以根据需要修改超参数的取值及其组合。