pytorch 线性回归模拟sin(x)
时间: 2024-04-11 21:06:32 浏览: 140
下面是一个简单的 PyTorch 线性回归模型来拟合 sin(x) 函数的示例:
```
import torch
import math
import matplotlib.pyplot as plt
# 生成训练数据
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
# 定义模型
class LinearRegression(torch.nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearRegression, self).__init__()
self.linear = torch.nn.Linear(input_dim, output_dim)
def forward(self, x):
out = self.linear(x)
return out
input_dim = 1
output_dim = 1
model = LinearRegression(input_dim, output_dim)
# 定义损失函数和优化器
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
num_epochs = 2000
for epoch in range(num_epochs):
inputs = x.reshape(-1, 1)
labels = y.reshape(-1, 1)
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch+1) % 100 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
# 可视化结果
predicted = model(inputs).detach().numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original data')
plt.plot(x.numpy(), predicted, label='Fitted line')
plt.legend()
plt.show()
```
在上面的代码中,我们首先生成了一个包含 2000 个样本的训练数据集,这些样本在 -π 到 π 的范围内均匀分布。接着,我们定义了一个简单的线性回归模型,该模型包含一个线性层。然后,我们使用均方误差损失函数和随机梯度下降优化器来训练模型。最后,我们将训练数据集和模型的预测结果可视化,以便于观察模型的性能。
在运行上面的代码之后,我们应该可以看到类似于下面的输出:
```
Epoch [100/2000], Loss: 0.4835
Epoch [200/2000], Loss: 0.3817
Epoch [300/2000], Loss: 0.3055
Epoch [400/2000], Loss: 0.2471
Epoch [500/2000], Loss: 0.2012
Epoch [600/2000], Loss: 0.1644
Epoch [700/2000], Loss: 0.1341
Epoch [800/2000], Loss: 0.1084
Epoch [900/2000], Loss: 0.0864
Epoch [1000/2000], Loss: 0.0674
Epoch [1100/2000], Loss: 0.0511
Epoch [1200/2000], Loss: 0.0370
Epoch [1300/2000], Loss: 0.0250
Epoch [1400/2000], Loss: 0.0149
Epoch [1500/2000], Loss: 0.0067
Epoch [1600/2000], Loss: 0.0009
Epoch [1700/2000], Loss: 0.0037
Epoch [1800/2000], Loss: 0.0089
Epoch [1900/2000], Loss: 0.0156
Epoch [2000/2000], Loss: 0.0236
```
该输出显示了模型的平均训练损失在每个周期中的降低。最终,我们应该可以看到一个可视化图,其中红色圆圈表示原始数据,蓝色曲线表示模型的预测结果。
阅读全文