使用keras通过LSTM实现股价预测的代码
时间: 2024-05-07 14:18:30 浏览: 86
下面是使用Keras通过LSTM实现股价预测的代码:
```python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense, LSTM, Dropout
# 读取数据
df = pd.read_csv('stock.csv')
# 取出收盘价列
data = df.iloc[:, 4:5].values
# 数据归一化
sc = MinMaxScaler(feature_range=(0, 1))
data = sc.fit_transform(data)
# 划分训练集和测试集
training_size = int(len(data) * 0.7)
test_size = len(data) - training_size
train_data = data[0:training_size, :]
test_data = data[training_size:len(data), :]
# 将序列转换为监督学习数据
def create_dataset(dataset, look_back=1):
data_x, data_y = [], []
for i in range(len(dataset) - look_back - 1):
a = dataset[i:(i+look_back), 0]
data_x.append(a)
data_y.append(dataset[i+look_back, 0])
return np.array(data_x), np.array(data_y)
look_back = 60
train_x, train_y = create_dataset(train_data, look_back)
test_x, test_y = create_dataset(test_data, look_back)
# 转换为3维数据
train_x = np.reshape(train_x, (train_x.shape[0], train_x.shape[1], 1))
test_x = np.reshape(test_x, (test_x.shape[0], test_x.shape[1], 1))
# 定义LSTM模型
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(look_back, 1)))
model.add(Dropout(0.2))
model.add(LSTM(units=50, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(units=50))
model.add(Dropout(0.2))
model.add(Dense(units=1))
# 编译模型
model.compile(optimizer='adam', loss='mean_squared_error')
# 训练模型
model.fit(train_x, train_y, epochs=100, batch_size=32)
# 预测
predicted_stock_price = model.predict(test_x)
predicted_stock_price = sc.inverse_transform(predicted_stock_price)
# 可视化结果
plt.plot(test_data[look_back:], color='blue', label='Real Stock Price')
plt.plot(predicted_stock_price, color='red', label='Predicted Stock Price')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('Stock Price')
plt.legend()
plt.show()
```
注意,这里的数据是使用Pandas从CSV文件中读取的,你需要将代码中的文件名修改为你自己的文件名。在训练模型时,我们使用了100个epochs,你可以根据需要调整这个值。在预测后,我们使用了`inverse_transform`方法将归一化的数据转换为原始数据。最后,我们使用Matplotlib将真实数据和预测结果进行可视化。
阅读全文