pytorch中SGD使用学习率指数衰减代码
时间: 2024-10-25 14:03:13 浏览: 41
在PyTorch中,如果你想要使用学习率指数衰减策略,可以使用`torch.optim.lr_scheduler`模块提供的`ExponentialLR`函数。这是一个简单例子:
```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR
# 创建一个模型、优化器和学习率衰减器
model = ... # 你的模型实例
optimizer = SGD(model.parameters(), lr=0.1) # 初始化SGD优化器,初始学习率为0.1
scheduler = ExponentialLR(optimizer, gamma=0.95) # gamma设置为指数衰减因子,默认值为0.999
# 训练循环
for epoch in range(num_epochs):
scheduler.step() # 在每个epoch开始前调整学习率
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
```
在这个例子中,`gamma`参数决定了每次衰减的速度。每经过一个epoch,学习率会被乘以`gamma`。如果你想自定义衰减周期或其他细节,你还可以通过`ExponentialLR`的构造函数传入相应的参数。
阅读全文