pytorch实现L2正则
时间: 2023-10-23 11:29:02 浏览: 87
在 PyTorch 中,可以通过在模型定义中的参数上添加 L2 正则化的权重来实现 L2 正则化。具体来说,可以使用 `torch.nn.Module` 类的 `register_parameter()` 方法来注册模型参数并指定它们的正则化权重。
例如,假设我们有一个线性层,我们想要对其权重进行 L2 正则化,可以在模型定义中添加如下代码:
```python
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_size, output_size):
super(MyModel, self).__init__()
self.linear = nn.Linear(input_size, output_size)
self.register_parameter("weight_decay", nn.Parameter(torch.tensor(1e-5)))
# 注册正则化权重
def forward(self, x):
out = self.linear(x)
return out
def l2_regularization(self):
return self.weight_decay * torch.norm(self.linear.weight)
# 计算 L2 正则化项
model = MyModel(10, 5)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
# 在训练循环中,计算损失函数时加上正则化项
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target) + model.l2_regularization()
loss.backward()
optimizer.step()
```
在上面的代码中,我们使用 `register_parameter()` 方法注册了一个名为 `weight_decay` 的正则化权重,并在模型的 `forward()` 方法中计算了 L2 正则化项。在训练循环中,我们使用 `model.l2_regularization()` 方法计算正则化项并将其加到损失函数中。最后,使用反向传播和优化器来更新模型的权重和偏置。
需要注意的是,这里的 `torch.norm()` 函数默认计算 L2 范数。如果要使用其他范数,请参考 PyTorch 文档中的说明。
阅读全文