基于pytorch写一段回归预测代码
时间: 2024-02-11 20:09:34 浏览: 36
以下是使用 PyTorch 实现神经网络进行回归预测的 Python 代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# 创建一些随机的数据用于回归预测
X = np.sort(5 * np.random.rand(80, 1), axis=0)
y = np.sin(X).ravel()
# 将数据转换为 PyTorch 张量
X_tensor = torch.from_numpy(X).float()
y_tensor = torch.from_numpy(y).float()
# 定义神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(1, 10)
self.fc2 = nn.Linear(10, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
net = Net()
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)
# 训练模型
for epoch in range(1000):
optimizer.zero_grad()
y_pred = net(X_tensor)
loss = criterion(y_pred, y_tensor)
loss.backward()
optimizer.step()
# 预测
y_pred_tensor = net(X_tensor)
y_pred = y_pred_tensor.detach().numpy()
# 可视化结果
import matplotlib.pyplot as plt
plt.scatter(X, y, color='darkorange', label='data')
plt.plot(X, y_pred, color='navy', label='Neural Network Regression')
plt.legend()
plt.show()
```
这段代码首先使用 Numpy 创建了一些随机的数据用于回归预测,并将数据转换为 PyTorch 张量。然后,定义了一个包含两个全连接层的神经网络模型,以及损失函数和优化器。接下来,使用训练数据对模型进行训练,并使用训练好的模型对训练数据进行预测。最后,使用 Matplotlib 将预测结果可视化出来。
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)