Write Python code for PCA-LSTM stock prediction
Posted: 2023-12-08 14:04:30
Here is a simple PCA-LSTM stock prediction example in Python, using the scikit-learn and Keras libraries:
```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense, LSTM
# Load the data (column 0 is assumed to hold the date; the other columns must be numeric)
df = pd.read_csv('stock_data.csv')
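# Preprocessing: scale every feature column to [0, 1], skipping column 0.
# The scaler and PCA are fitted on the first 80% of rows (the training part)
# only, so that no information from the test period leaks into the model.
values = df.iloc[:, 1:].values
split = int(len(values) * 0.8)
scaler = MinMaxScaler()
scaler.fit(values[:split])
scaled_data = scaler.transform(values)
# PCA: reduce the features to a single principal component
pca = PCA(n_components=1)
pca.fit(scaled_data[:split])
reduced_data = pca.transform(scaled_data)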
# Split chronologically into a training set (80%) and a test set (20%)
train_size = int(len(reduced_data) * 0.8)
train_data = reduced_data[:train_size]
test_data = reduced_data[train_size:]
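# Turn a series into sliding-window samples: look_back past values -> next value
def create_dataset(dataset, look_back=1):
    dataX, dataY = [], []
    # The trailing -1 leaves the last possible window unused; it is kept so the
    # plotting offsets at the end of the script still line up.
    for i in range(len(dataset) - look_back - 1):
        a = dataset[i:(i + look_back), 0]
        dataX.append(a)
        dataY.append(dataset[i + look_back, 0])
    return np.array(dataX), np.array(dataY)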
look_back = 30
trainX, trainY = create_dataset(train_data, look_back=look_back)
testX, testY = create_dataset(test_data, look_back=look_back)
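# Reshape to (samples, timesteps, features), the input layout an LSTM layer expects
trainX = np.reshape(trainX, (trainX.shape[0], trainX.shape[1], 1))
testX = np.reshape(testX, (testX.shape[0], testX.shape[1], 1))
# Build and train the LSTM model: one 50-unit LSTM layer and a linear output
model = Sequential()
model.add(LSTM(50, input_shape=(look_back, 1)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
# batch_size=1 trains slowly; a larger batch size speeds things up
model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=2)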
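# Predict the first principal component for both sets
trainPredict = model.predict(trainX)
testPredict = model.predict(testX)
# Map back to the original units: undo the PCA projection first, then the scaling.
# Calling scaler.inverse_transform directly on the one-column predictions fails
# whenever the CSV has more than one feature column.
trainPredict = scaler.inverse_transform(pca.inverse_transform(trainPredict))
testPredict = scaler.inverse_transform(pca.inverse_transform(testPredict))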
# Plot the original series together with the training and test predictions
import matplotlib.pyplot as plt
trainPredictPlot = np.empty_like(scaled_data)
trainPredictPlot[:, :] = np.nan
trainPredictPlot[look_back:len(trainPredict)+look_back, :] = trainPredict
testPredictPlot = np.empty_like(scaled_data)
testPredictPlot[:, :] = np.nan
testPredictPlot[len(trainPredict)+(look_back*2)+1:len(scaled_data)-1, :] = testPredict
plt.plot(scaler.inverse_transform(scaled_data))
plt.plot(trainPredictPlot)
plt.plot(testPredictPlot)
plt.show()
```
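To see what the sliding-window step feeds into the LSTM, here is a minimal sketch on a toy series. The helper `make_windows` is a hypothetical name, not part of the answer's code; unlike `create_dataset` above it uses every available window (no trailing `-1`) and takes a plain 1-D series, so treat it as an illustration of the idea rather than a drop-in replacement.

```python
import numpy as np


def make_windows(series, look_back):
    """Split a 1-D series into (samples, look_back) inputs and next-step targets.

    Hypothetical helper for illustration only.
    """
    series = np.asarray(series, dtype=float)
    X, y = [], []
    for i in range(len(series) - look_back):
        X.append(series[i:i + look_back])   # look_back consecutive values
        y.append(series[i + look_back])     # the value that follows them
    return np.array(X), np.array(y)


toy = np.arange(6)                          # 0, 1, 2, 3, 4, 5
X, y = make_windows(toy, look_back=3)

# An LSTM layer expects input shaped (samples, timesteps, features)
X3d = X.reshape(X.shape[0], X.shape[1], 1)

print(X)
print(y)
print(X3d.shape)
```

Each row of `X` is one window of past values and the matching entry of `y` is the value the model is asked to predict, which is why a longer `look_back` leaves fewer training samples.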
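A few things to note:
- `stock_data.csv` is a CSV file containing historical stock price data.
- This code is for reference and learning only and should not be used for real investment decisions.
- Real stock price data is complex and highly uncertain, so the predictions may be neither accurate nor reliable.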
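One way to judge how far the predictions can be trusted is to compare their error with a naive "tomorrow equals today" forecast. The sketch below does that with RMSE. The function names and the numbers are made up for illustration and are not taken from any real stock or from the model above.

```python
import math


def rmse(actual, predicted):
    """Root-mean-square error between two equal-length sequences."""
    if len(actual) != len(predicted):
        raise ValueError("sequences must have the same length")
    squared = [(a - p) ** 2 for a, p in zip(actual, predicted)]
    return math.sqrt(sum(squared) / len(squared))


def persistence_rmse(series):
    """RMSE of the naive forecast that repeats the previous value."""
    return rmse(series[1:], series[:-1])


# Hypothetical closing prices and model outputs, for illustration only
prices = [10.0, 10.5, 10.2, 10.8, 11.0]
model_forecast = [10.4, 10.4, 10.6, 10.9]   # forecasts for prices[1:]

model_error = rmse(prices[1:], model_forecast)
naive_error = persistence_rmse(prices)
print(f"model RMSE: {model_error:.3f}, persistence RMSE: {naive_error:.3f}")
```

If the model's RMSE on the test set is not clearly below the persistence RMSE, the LSTM is adding little beyond copying the last observed value.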