生成二层神经网络超参数调优pytorch代码
时间: 2023-08-24 08:09:50 浏览: 117
pytorch自定义二值化网络层方式
以下示例代码在 MNIST 数据集上通过网格搜索，为二层全连接神经网络选择隐藏层神经元数量：
``` python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.optim.lr_scheduler import StepLR
# Define neural network architecture
class Net(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        num_hidden: Width of the hidden layer.
        in_features: Number of input features after flattening every
            dimension except the first (784 = 28 * 28 for MNIST images).
        num_classes: Number of output logits.
    """

    def __init__(self, num_hidden=32, in_features=784, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) logits of shape (batch, num_classes).

        ``x`` is flattened from dimension 1 onward, so any input whose
        trailing dimensions multiply to ``in_features`` is accepted
        (e.g. (batch, 1, 28, 28) MNIST images). No softmax is applied:
        the logits are meant for ``nn.CrossEntropyLoss`` / ``argmax``.
        """
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Set hyperparameters
num_epochs = 20
batch_size = 64
learning_rate = 0.1
# StepLR multiplies the learning rate by lr_gamma every lr_step_size epochs.
# (step_size=1 with gamma=0.1 would shrink the rate 10x per epoch and
# effectively stop learning after a few epochs.)
lr_step_size = 10
lr_gamma = 0.1
num_hidden_choices = [8, 16, 32, 64, 128]
val_size = 5000  # training images held out for hyperparameter selection
split_seed = 0  # fixed so every configuration sees the same train/val split
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the data
train_data = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
test_data = MNIST(root='./data', train=False, transform=ToTensor(), download=True)
# Select hyperparameters on a validation split, not on the test set:
# picking the configuration with the best *test* accuracy leaks the test set
# into model selection and makes the reported accuracy optimistic.
train_subset, val_subset = torch.utils.data.random_split(
    train_data,
    [len(train_data) - val_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


def _train(net, loader, criterion, optimizer, scheduler, epochs, device):
    """Train *net* in place for *epochs* passes over *loader*.

    The scheduler is stepped once per epoch, after that epoch's optimizer
    steps, so the learning-rate schedule actually applies during training.
    """
    for _ in range(epochs):
        net.train()
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
        scheduler.step()


def _evaluate(net, loader, device):
    """Return the fraction of samples in *loader* that *net* classifies correctly.

    Returns 0.0 for an empty loader instead of dividing by zero.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total if total else 0.0


# Grid search for optimal hyperparameters
criterion = nn.CrossEntropyLoss()
# Start below any achievable accuracy so the first configuration is always
# recorded and best_num_hidden is never left undefined.
best_acc = -1.0
best_num_hidden = None
best_state = None
for num_hidden in num_hidden_choices:
    net = Net(num_hidden).to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
    # Build the scheduler before training so it decays the learning rate
    # while the network is being trained.
    scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
    _train(net, train_loader, criterion, optimizer, scheduler, num_epochs, device)
    acc = _evaluate(net, val_loader, device)
    print(f"num_hidden={num_hidden}: validation accuracy={acc:.4f}")
    if acc > best_acc:
        best_acc = acc
        best_num_hidden = num_hidden
        # Snapshot the weights so the winner can be tested without retraining.
        best_state = {k: v.detach().clone() for k, v in net.state_dict().items()}

# Touch the test set exactly once, for the selected configuration only.
best_net = Net(best_num_hidden).to(device)
best_net.load_state_dict(best_state)
test_acc = _evaluate(best_net, test_loader, device)
print("Best validation accuracy:", best_acc)
print("Best number of hidden neurons:", best_num_hidden)
print("Test accuracy of best model:", test_acc)
```
阅读全文