举一个使用参数组迭代器的实例
时间: 2024-04-12 19:30:04 浏览: 64
当你有多个参数组,每个参数组具有不同的学习率和权重衰减时,可以使用参数组迭代器来灵活地定义不同的优化策略。以下是一个使用参数组迭代器的示例:
```python
import torch
import torch.optim as optim
# 定义一个模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = torch.nn.Linear(10, 5)
self.fc2 = torch.nn.Linear(5, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = MyModel()
# 定义不同参数组的学习率和权重衰减
params = [
{'params': model.fc1.parameters(), 'lr': 0.01, 'weight_decay': 0.001},
{'params': model.fc2.parameters(), 'lr': 0.001, 'weight_decay': 0.0001}
]
# 使用参数组迭代器创建优化器
optimizer = optim.SGD(params, lr=0.1)
# 打印每个参数组的学习率和权重衰减
for i, param_group in enumerate(optimizer.param_groups):
print(f"参数组 {i+1}: 学习率={param_group['lr']}, 权重衰减={param_group['weight_decay']}")
```
在这个示例中,我们定义了一个模型 `MyModel`,它有两个线性层 `fc1` 和 `fc2`。我们使用一个参数组迭代器 `params` 来创建优化器 `optimizer`。其中,`params` 是一个包含两个参数组的列表,每个参数组都分别指定了学习率和权重衰减的值。最后,我们遍历优化器的参数组,并打印每个参数组的学习率和权重衰减。
请根据你的具体情况替换示例代码中的模型、参数组和优化器类型,并确保参数组迭代器的格式正确。