Writing a Linear Regression Model with PyTorch
Date: 2024-05-24 19:12:53
Below is a simple linear regression model implemented with PyTorch:
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Prepare the data: y = 2x + 1
x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1)
y_values = [2 * i + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1)

# Define the model
class LinearRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

input_dim = 1
output_dim = 1
model = LinearRegression(input_dim, output_dim)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Train the model (the tensors are constant, so convert them once
# outside the loop instead of on every epoch)
epochs = 1000
inputs = torch.from_numpy(x_train)
labels = torch.from_numpy(y_train)
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 50 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, epochs, loss.item()))

# Visualize the result
predicted = model(inputs).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()
```
Output:
```
Epoch [50/1000], Loss: 4.4394
Epoch [100/1000], Loss: 2.0685
Epoch [150/1000], Loss: 0.9636
Epoch [200/1000], Loss: 0.4479
Epoch [250/1000], Loss: 0.2093
Epoch [300/1000], Loss: 0.0981
Epoch [350/1000], Loss: 0.0460
Epoch [400/1000], Loss: 0.0216
Epoch [450/1000], Loss: 0.0102
Epoch [500/1000], Loss: 0.0048
Epoch [550/1000], Loss: 0.0023
Epoch [600/1000], Loss: 0.0011
Epoch [650/1000], Loss: 0.0005
Epoch [700/1000], Loss: 0.0003
Epoch [750/1000], Loss: 0.0001
Epoch [800/1000], Loss: 0.0001
Epoch [850/1000], Loss: 0.0000
Epoch [900/1000], Loss: 0.0000
Epoch [950/1000], Loss: 0.0000
Epoch [1000/1000], Loss: 0.0000
```
Visualization:
![Linear Regression Model](https://i.imgur.com/7QX8tHx.png)
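Since the training data follows y = 2x + 1 exactly, a useful sanity check is to inspect the learned weight and bias, which should converge close to 2 and 1, and then run inference on an unseen input under `torch.no_grad()`. A minimal self-contained sketch (it repeats the training loop above with the same data and hyperparameters, so it can run on its own):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same data as above: y = 2x + 1 for x = 0..10
x = torch.arange(11, dtype=torch.float32).reshape(-1, 1)
y = 2 * x + 1

# nn.Linear(1, 1) is equivalent to the LinearRegression module above
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(1000):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

# The learned parameters should approach the true slope and intercept
weight = model.weight.item()
bias = model.bias.item()
print('weight: {:.3f}, bias: {:.3f}'.format(weight, bias))

# Inference on an unseen input; no_grad() skips building the autograd graph
with torch.no_grad():
    pred = model(torch.tensor([[12.0]])).item()
print('prediction for x=12: {:.3f}'.format(pred))
```

For the true function, x = 12 gives y = 25, so the prediction should land close to that value once the loss has dropped near zero.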