Simple PyTorch-based LSTM regression prediction code, where the dataset is an Excel file with two columns
Sure, here is a simple PyTorch-based LSTM regression example for predicting the trend of the two-column data in an Excel file:
```python
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
# Data preprocessing: load the two-column Excel file and scale both columns to [0, 1]
data = pd.read_excel('data.xlsx')
scaler = MinMaxScaler()
data = scaler.fit_transform(data)
data = torch.FloatTensor(data)
# Define the LSTM model
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        # Fresh zero hidden and cell states for each sequence
        hidden = (torch.zeros(1, 1, self.hidden_size),
                  torch.zeros(1, 1, self.hidden_size))
        # Reshape to (seq_len, batch=1, input_size) as expected by nn.LSTM
        lstm_out, hidden = self.lstm(input.view(len(input), 1, -1), hidden)
        output = self.linear(lstm_out.view(len(input), -1))
        # Return only the prediction for the last time step
        return output[-1]
# Train the model
input_size = 1   # the first column is the only input feature
hidden_size = 4
output_size = 1  # the second column is the regression target
lr = 0.01
num_epochs = 100
model = LSTM(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(num_epochs):
    for i in range(len(data) - 1):
        input_data = data[i:i+1, :-1]   # first column at step i
        target = data[i+1:i+2, -1]      # second column at step i+1
        output = model(input_data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch % 10 == 0:
        print('Epoch: {}/{}..........Loss: {:.6f}'.format(epoch, num_epochs, loss.item()))
# Predict: one-step-ahead forecasts over the whole series
model.eval()
predictions = []
with torch.no_grad():
    for i in range(len(data) - 1):
        input_data = data[i:i+1, :-1]
        output = model(input_data)
        predictions.append(output.item())
# Inverse scaling: the scaler was fitted on two columns, so pad a dummy
# first column before calling inverse_transform and keep only the target column
predictions = np.array(predictions).reshape(-1, 1)
padded = np.concatenate([np.zeros_like(predictions), predictions], axis=1)
predictions = scaler.inverse_transform(padded)[:, -1]
original = scaler.inverse_transform(data.numpy())[:, -1]
# Visualize the predictions against the original target column
import matplotlib.pyplot as plt
plt.plot(range(1, len(data)), predictions, label='Predictions')
plt.plot(original, label='Original Data')
plt.legend()
plt.show()
```
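Because `forward()` reshapes its input to `(seq_len, batch=1, input_size)` and only returns the prediction for the last time step, the same model can also be trained on a sliding window of past values instead of a single value. The snippet below is a minimal sketch of that variant; `window_size` is a hypothetical parameter that is not part of the original example:
```python
# Sliding-window training sketch (assumes the model, data, criterion and
# optimizer defined above; window_size is a hypothetical parameter)
window_size = 10
for epoch in range(num_epochs):
    for i in range(len(data) - window_size):
        input_seq = data[i:i+window_size, :-1]             # past window of the first column
        target = data[i+window_size:i+window_size+1, -1]   # next value of the second column
        output = model(input_seq)                          # forward() already handles seq_len > 1
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```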
Replace `data.xlsx` with the path to your own Excel file. You will also need pandas, NumPy, scikit-learn, and matplotlib installed, plus openpyxl so that pandas can read `.xlsx` files.
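If you do not have a dataset at hand, a minimal sketch like the following can generate a compatible two-column Excel file for testing; the column names `feature` and `target` are placeholders, and writing `.xlsx` files with pandas also requires openpyxl:
```python
import numpy as np
import pandas as pd

# Generate a synthetic two-column dataset: the first column is the input
# feature, the second column is the value to be predicted
t = np.linspace(0, 20, 200)
df = pd.DataFrame({'feature': np.sin(t), 'target': np.sin(t + 0.1)})
df.to_excel('data.xlsx', index=False)
```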