l2正则化代码pytorch
时间: 2023-08-30 15:09:40 浏览: 100
PyTorch中的正则化:提升模型性能的秘诀
以下是在PyTorch中实现L2正则化的示例代码:
```python
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear1 = nn.Linear(10, 5)
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
x = self.linear1(x)
x = nn.functional.relu(x)
x = self.linear2(x)
return x
model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.1)
```
在这个示例中,我们定义了一个包含两个线性层的神经网络模型。在定义模型时,我们没有显式地添加L2正则化项,而是在定义优化器时使用了`weight_decay`参数。这个参数告诉优化器在更新模型权重时,将权重乘以一个小于1的因子,从而实现L2正则化。在这个示例中,`weight_decay`的值为0.1,表示L2正则化的强度为0.1。
在训练模型时,我们可以像通常一样使用`optimizer.step()`来更新模型,优化器会自动应用L2正则化。
阅读全文