L2正则化pytorch代码
时间: 2023-10-27 09:07:00 浏览: 71
L2正则化是通过在目标函数中加入一个正则项来惩罚模型的复杂度,从而防止过拟合。在PyTorch中,可以通过在优化器中设置weight_decay参数来实现L2正则化。具体的代码如下所示:
```
import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 定义数据和优化器
inputs = torch.randn(32, 10)
labels = torch.randn(32, 1)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)
# 训练模型
for epoch in range(100):
optimizer.zero_grad()
outputs = model(inputs)
loss = nn.MSELoss()(outputs, labels)
loss.backward()
optimizer.step()
```
在上面的代码中,我们定义了一个简单的神经网络模型,并使用SGD优化器进行训练。在定义优化器时,我们设置了weight_decay参数为0.001,这就是L2正则化的实现方式。
阅读全文