PyTorch搭建LSTM模型,输入csv数据,划分训练集和测试集,实现多变量时间序列预测并且将预测结果可视化
时间: 2024-04-29 17:25:55 浏览: 29
首先,需要导入PyTorch和其他必要的库:
```python
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
```
接着,读入csv数据(第一列作为索引,最后一列是待预测的目标变量),并用MinMaxScaler把每一列归一化到[0, 1]区间:
```python
# Load the dataset; index_col=0 makes the first CSV column the row index,
# so it is not treated as a feature.
data = pd.read_csv('data.csv', index_col=0)
# Rescale every column with MinMaxScaler (default target range is [0, 1]).
# ``scaler`` stays bound so the scaling can be reversed later with
# ``scaler.inverse_transform``.
# NOTE(review): the scaler is fitted on the whole dataset here. If the data is
# split into train/test afterwards, the test rows' min/max have already
# influenced the scaling (data leakage) -- consider fitting on the training
# rows only.
scaler = MinMaxScaler()
# From here on ``data`` is the scaler's output (a NumPy array by default),
# no longer the DataFrame.
data = scaler.fit_transform(data)
```
然后,定义划分训练集和测试集的函数:
```python
def train_test_split(data, test_ratio=0.2):
    """Split ``data`` in order into a leading train part and a trailing test part.

    No shuffling is performed: the last ``int(len(data) * test_ratio)`` items
    form the test set and everything before them forms the training set, which
    keeps the chronological order of a time series intact.

    Args:
        data: Any sliceable sequence that supports ``len()`` (list, NumPy
            array, ...).
        test_ratio: Fraction of the items to reserve for testing; must lie in
            the closed interval [0, 1].

    Returns:
        A ``(train_data, test_data)`` tuple of slices of ``data``.

    Raises:
        ValueError: If ``test_ratio`` is outside [0, 1]. Such values would make
            ``train_size`` negative and silently yield wrong slices.
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(
            'test_ratio must be between 0 and 1, got {!r}'.format(test_ratio)
        )
    test_size = int(len(data) * test_ratio)
    train_size = len(data) - test_size
    return data[:train_size], data[train_size:]
```
接着,定义LSTM模型:
```python
class LSTM(nn.Module):
    """Single-layer LSTM followed by a linear read-out.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of values produced by the linear layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # nn.LSTM is left at its default layout: (seq_len, batch, input_size).
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Run one sequence through the network and return the last step's output.

        Args:
            input: Tensor whose first dimension is the time axis; each entry is
                flattened into one feature vector of length ``input_size``.

        Returns:
            Tensor of shape ``(output_size,)``: the linear layer's output for
            the final time step.
        """
        seq_len = len(input)
        # Insert a batch dimension of 1: the whole input is one sequence.
        lstm_out, _ = self.lstm(input.view(seq_len, 1, -1))
        # Drop the batch dimension again so the linear layer sees one row per step.
        prediction = self.linear(lstm_out.view(seq_len, -1))
        return prediction[-1]
```
接着,定义训练函数。每一步用第 i 行的特征(除最后一列以外的所有列)去预测第 i+1 行最后一列的目标值:
```python
def train(model, train_data, optimizer, criterion, num_epochs):
    """Fit ``model`` for one-step-ahead prediction, one sample per update.

    For every row ``i`` the features (all columns except the last) are fed to
    the model and the output is compared with the last column of row ``i + 1``.
    Each step passes a single row, so the model sees a sequence of length one.

    Args:
        model: ``nn.Module`` mapping a ``(1, n_features)`` tensor to a tensor
            matching the target's shape ``(1,)``.
        train_data: 2-D array-like of shape ``(n_rows, n_features + 1)``; the
            last column is the prediction target.
        optimizer: Optimizer already bound to ``model``'s parameters.
        criterion: Loss callable, e.g. ``nn.MSELoss()``.
        num_epochs: Number of passes over ``train_data``.

    Raises:
        ValueError: If ``train_data`` has fewer than two rows, because no
            (input, next-step target) pair can be formed.
    """
    if len(train_data) < 2:
        raise ValueError('train_data needs at least 2 rows to build a training pair')

    # Convert once up front instead of creating fresh tensors on every step.
    series = torch.as_tensor(train_data, dtype=torch.float32)
    # unsqueeze(1) makes iteration yield (1, n_features) inputs and (1,) targets.
    inputs = series[:-1, :-1].contiguous().unsqueeze(1)
    targets = series[1:, -1].contiguous().unsqueeze(1)
    num_steps = len(inputs)

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for x, y in zip(inputs, targets):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        if epoch % 10 == 0:
            # Report the epoch mean; the last sample's loss alone is too noisy.
            print('Epoch: {}, Loss: {:.5f}'.format(epoch, epoch_loss / num_steps))
```
接着,定义测试函数:
```python
def test(model, test_data):
    """Return one-step-ahead predictions for every row of ``test_data`` but the last.

    Row ``i``'s features (all columns except the last) produce the prediction
    for row ``i + 1``, so ``len(test_data) - 1`` values are returned.

    Args:
        model: ``nn.Module`` mapping a ``(1, n_features)`` tensor to a
            single-element tensor.
        test_data: 2-D array-like of shape ``(n_rows, n_features + 1)``.

    Returns:
        A list of Python floats, one per predicted step; empty when
        ``test_data`` has fewer than two rows.
    """
    if len(test_data) < 2:
        return []

    # The last row is dropped: there is no following row to predict.
    series = torch.as_tensor(test_data, dtype=torch.float32)
    inputs = series[:-1, :-1].contiguous().unsqueeze(1)

    was_training = model.training
    model.eval()
    try:
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            predictions = [model(x).item() for x in inputs]
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return predictions


# The name starts with "test", so pytest would try to collect this helper as a
# test case if it is ever imported into a test module; opt out explicitly.
test.__test__ = False
```
最后,将预测结果可视化:
```python
def plot_predictions(predictions, test_data):
    """Plot the predicted series against the true target values.

    Args:
        predictions: Sequence of predicted values, one per row of
            ``test_data`` starting at row 1.
        test_data: 2-D array whose last column holds the true target.
    """
    # Skip row 0 so the true values line up with predictions that start at row 1.
    test_y = test_data[1:, -1]
    plt.plot(test_y, label='True Data')
    plt.plot(predictions, label='Predictions')
    plt.legend()
    plt.show()
```
完整代码如下:
```python
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
# ----------------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------------
def train_test_split(data, test_ratio=0.2):
    """Split ``data`` in order: leading rows for training, trailing rows for testing.

    Raises:
        ValueError: If ``test_ratio`` is outside [0, 1].
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(
            'test_ratio must be between 0 and 1, got {!r}'.format(test_ratio)
        )
    test_size = int(len(data) * test_ratio)
    train_size = len(data) - test_size
    return data[:train_size], data[train_size:]


class LSTM(nn.Module):
    """Single-layer LSTM followed by a linear read-out."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # nn.LSTM is left at its default layout: (seq_len, batch, input_size).
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Return the linear output for the last time step, shape ``(output_size,)``."""
        seq_len = len(input)
        # Insert a batch dimension of 1: the whole input is one sequence.
        lstm_out, _ = self.lstm(input.view(seq_len, 1, -1))
        prediction = self.linear(lstm_out.view(seq_len, -1))
        return prediction[-1]


def train(model, train_data, optimizer, criterion, num_epochs):
    """Fit ``model`` to predict row ``i + 1``'s last column from row ``i``'s features.

    Raises:
        ValueError: If ``train_data`` has fewer than two rows.
    """
    if len(train_data) < 2:
        raise ValueError('train_data needs at least 2 rows to build a training pair')

    # Convert once up front instead of creating fresh tensors on every step.
    series = torch.as_tensor(train_data, dtype=torch.float32)
    inputs = series[:-1, :-1].contiguous().unsqueeze(1)
    targets = series[1:, -1].contiguous().unsqueeze(1)
    num_steps = len(inputs)

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for x, y in zip(inputs, targets):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        if epoch % 10 == 0:
            # Report the epoch mean; the last sample's loss alone is too noisy.
            print('Epoch: {}, Loss: {:.5f}'.format(epoch, epoch_loss / num_steps))


def test(model, test_data):
    """Return one-step-ahead predictions for every row of ``test_data`` but the last."""
    if len(test_data) < 2:
        return []

    series = torch.as_tensor(test_data, dtype=torch.float32)
    inputs = series[:-1, :-1].contiguous().unsqueeze(1)

    was_training = model.training
    model.eval()
    try:
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            predictions = [model(x).item() for x in inputs]
    finally:
        model.train(was_training)
    return predictions


def plot_predictions(predictions, test_data):
    """Plot the predicted series against the true target (last column)."""
    # Skip row 0 so the true values line up with predictions that start at row 1.
    test_y = test_data[1:, -1]
    plt.plot(test_y, label='True Data')
    plt.plot(predictions, label='Predictions')
    plt.legend()
    plt.show()


# ----------------------------------------------------------------------------
# Pipeline
# ----------------------------------------------------------------------------
# The first CSV column is the index; the last remaining column is the target.
data = pd.read_csv('data.csv', index_col=0).to_numpy(dtype=float)

# Split BEFORE scaling and fit the scaler on the training rows only, so the
# test rows' min/max never leak into training.
train_data, test_data = train_test_split(data)
scaler = MinMaxScaler()
train_data = scaler.fit_transform(train_data)
test_data = scaler.transform(test_data)

# Every column except the target is an input feature.
model = LSTM(input_size=train_data.shape[1] - 1, hidden_size=64, output_size=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
num_epochs = 100

train(model, train_data, optimizer, criterion, num_epochs)
predictions = test(model, test_data)
plot_predictions(predictions, test_data)
```
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)