pytorch 线性回归 例子
时间: 2024-09-15 13:00:51 浏览: 52
PyTorch是一个广泛使用的深度学习框架,线性回归是其中最基础的模型之一,它用于预测连续数值输出。下面是一个简单的线性回归例子,我们将使用PyTorch创建一个线性模型并训练它。
首先,你需要安装PyTorch库,如果还没有安装,可以使用pip安装:
```bash
pip install torch
```
然后,我们可以编写如下的Python代码来实现线性回归:
```python
import torch
import torch.nn as nn
# 定义模型
class LinearRegression(nn.Module):
def __init__(self, input_dim):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(input_dim, 1) # 一维输出,因为我们预测的是单变量的值
def forward(self, x):
return self.linear(x)
# 数据准备(这里假设我们有一个二维输入的数据集)
x_data = torch.randn(100, 1)
y_data = 2 * x_data + 1 # 生成线性关系的目标数据,真实y = 2x + 1
# 将数据转化为张量,并设置为需要训练模式
x_data = x_data.type(torch.FloatTensor).requires_grad_()
y_data = y_data.type(torch.FloatTensor)
# 初始化模型参数并设置损失函数和优化器
model = LinearRegression(1)
criterion = nn.MSELoss() # 用于均方误差损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 使用随机梯度下降优化器
# 训练循环
num_epochs = 1000
for epoch in range(num_epochs):
predictions = model(x_data)
loss = criterion(predictions, y_data)
optimizer.zero_grad() # 清除梯度
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新模型参数
# 打印训练结果
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
阅读全文