pytorch实现线性回归
时间: 2023-07-18 09:07:03 浏览: 116
首先,我们需要导入 PyTorch 包:
```
import torch
```
然后,我们定义输入数据和目标输出:
```
# 输入数据
x_data = torch.Tensor([[1.0], [2.0], [3.0], [4.0]])
# 目标输出
y_data = torch.Tensor([[2.0], [4.0], [6.0], [8.0]])
```
接下来,我们定义模型。这里我们使用 PyTorch 中的 `nn.Module` 类来定义一个线性回归模型:
```
class LinearRegressionModel(torch.nn.Module):
def __init__(self):
super(LinearRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
y_pred = self.linear(x)
return y_pred
```
在模型中,我们定义了一个包含一个线性层的神经网络。该层接收一个输入特征,并输出一个标量结果。
接下来,我们创建一个模型实例,并定义损失函数和优化器:
```
# 创建模型实例
model = LinearRegressionModel()
# 定义损失函数
criterion = torch.nn.MSELoss()
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
```
在训练过程中,我们需要迭代多次,每次都计算模型的输出、损失和梯度,并更新模型参数:
```
# 迭代训练模型
for epoch in range(1000):
# 前向传播
y_pred = model(x_data)
# 计算损失
loss = criterion(y_pred, y_data)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 更新模型参数
optimizer.step()
# 每 100 次迭代输出一次损失
if (epoch+1) % 100 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 1000, loss.item()))
```
在训练完成后,我们可以使用模型进行预测:
```
# 预测结果
y_pred = model(torch.Tensor([[5.0]]))
print('预测结果: {:.2f}'.format(y_pred.item()))
```
完整代码如下:
```
import torch
# 输入数据
x_data = torch.Tensor([[1.0], [2.0], [3.0], [4.0]])
# 目标输出
y_data = torch.Tensor([[2.0], [4.0], [6.0], [8.0]])
class LinearRegressionModel(torch.nn.Module):
def __init__(self):
super(LinearRegressionModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
y_pred = self.linear(x)
return y_pred
# 创建模型实例
model = LinearRegressionModel()
# 定义损失函数
criterion = torch.nn.MSELoss()
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 迭代训练模型
for epoch in range(1000):
# 前向传播
y_pred = model(x_data)
# 计算损失
loss = criterion(y_pred, y_data)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 更新模型参数
optimizer.step()
# 每 100 次迭代输出一次损失
if (epoch+1) % 100 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 1000, loss.item()))
# 预测结果
y_pred = model(torch.Tensor([[5.0]]))
print('预测结果: {:.2f}'.format(y_pred.item()))
```
阅读全文