使用pytorch的api进行线性回归,其中y_true = 4 * x + 3.5,请用代码书写完成,最后绘制图片
时间: 2024-02-03 13:14:26 浏览: 72
以下是使用 PyTorch 实现的线性回归模型,同时绘制出训练数据和模型预测数据的图形:
``` python
import torch
import numpy as np
import matplotlib.pyplot as plt
# 定义训练数据
x_train = np.arange(-10, 10, 0.1)
y_train = 4 * x_train + 3.5
# 转换数据类型为 Tensor
x_train = torch.Tensor(x_train)
y_train = torch.Tensor(y_train)
# 定义模型
class LinearRegression(torch.nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
out = self.linear(x)
return out
# 实例化模型和损失函数
model = LinearRegression()
criterion = torch.nn.MSELoss()
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
num_epochs = 500
for epoch in range(num_epochs):
# 前向传播
y_pred = model(x_train)
# 计算损失函数并输出
loss = criterion(y_pred, y_train)
if (epoch+1) % 50 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 绘制图形
plt.scatter(x_train.numpy(), y_train.numpy(), color='blue')
plt.plot(x_train.numpy(), model(x_train).detach().numpy(), color='red')
plt.show()
```
运行以上代码,就可以看到绘制出的训练数据和模型预测数据的图形了。
阅读全文