pytorch线性回归
时间: 2023-10-18 18:14:46 浏览: 42
PyTorch是一个基于Python的深度学习框架,它支持动态计算图和静态计算图,能够快速构建和训练各种深度学习模型。其中,线性回归是最简单的机器学习算法之一,用于预测一个连续值的输出。
下面是一个使用PyTorch实现线性回归的示例代码:
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# 生成数据集
np.random.seed(1)
x = np.random.rand(100, 1)
y = 5 + 2 * x + np.random.rand(100, 1)
# 转换为张量
inputs = torch.from_numpy(x).float()
targets = torch.from_numpy(y).float()
# 定义模型
class LinearRegression(nn.Module):
def __init__(self, input_size, output_size):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(input_size, output_size)
def forward(self, x):
return self.linear(x)
input_size = 1
output_size = 1
model = LinearRegression(input_size, output_size)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
num_epochs = 1000
losses = []
for epoch in range(num_epochs):
# 前向传播和计算损失
outputs = model(inputs)
loss = criterion(outputs, targets)
losses.append(loss.item())
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印损失
if (epoch+1) % 50 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
# 可视化损失
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()
# 测试模型
with torch.no_grad():
predicted = model(inputs).detach().numpy()
# 可视化预测结果
plt.plot(x, y, 'ro', label='Original data')
plt.plot(x, predicted, label='Fitted line')
plt.legend()
plt.show()
```
这段代码首先生成了一个简单的数据集,然后将数据转换为PyTorch张量,定义了一个线性回归模型,并将损失函数设为均方误差(MSE)损失函数。然后使用随机梯度下降(SGD)优化器来训练模型,迭代1000次,每50次打印一次损失,并可视化损失曲线。最后,使用训练好的模型对输入数据进行预测,并可视化预测结果。
这个简单的示例演示了如何使用PyTorch构建和训练一个线性回归模型,并对预测结果进行可视化。实际上,PyTorch还提供了许多其他的深度学习算法和工具,可以帮助你构建和训练更复杂的深度学习模型。