L1正则化代码pytorch
时间: 2024-06-22 16:01:23 浏览: 85
在PyTorch中,L1正则化(也称为Lasso回归)通常是在模型的损失函数中添加一个权重衰减项,用来防止过拟合。这种正则化方法鼓励模型学习到稀疏的权重,即一些权重接近于0。以下是使用PyTorch实现L1正则化的简单示例:
```python
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_size, output_size):
super(MyModel, self).__init__()
# 在模型定义中添加参数
self.linear = nn.Linear(input_size, output_size)
def forward(self, x):
return self.linear(x)
# 添加L1正则化
def l1_regularizer(self, weight_decay=0.01):
# weight_decay是L1惩罚的强度
return torch.sum(torch.abs(self.linear.weight))
# 实例化模型并训练
model = MyModel(input_size, output_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01) # 注意weight_decay参数
loss_fn = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
# 训练步骤
optimizer.zero_grad()
outputs = model(x_train)
loss = loss_fn(outputs, y_train) + weight_decay * model.l1_regularizer()
loss.backward()
optimizer.step()
阅读全文