def predict(self, future_days=10): dataSetPast = self.dataset[-self.n_past: ] dataSetFuture = np.zeros((future_days, 2)) startDay = dataSetPast[-1][0]+1 dataSetFuture[:, 0] = np.arange(startDay, startDay+future_days) dataSetFull = np.concatenate((dataSetPast, dataSetFuture), axis=0) all_data = [] time_step = self.n_past for i in range(time_step, len(dataSetFull)): data_x = [] data_x.append( dataSetFull[i - time_step:i, :]) data_x = np.array(data_x) prediction = self.LSTModel.predict(data_x) all_data.append(prediction) dataSetFull[i, 1] = prediction
时间: 2023-06-26 09:06:50 浏览: 191
plot_cv_predict.zip_cross_val_predict_plot(cv._predict-python_预测
5星 · 资源好评率100%
这是一个用于预测未来天数股票价格的函数。它的输入参数是未来的天数,即需要预测的天数。函数首先根据历史数据(self.dataset)取出最近的self.n_past个数据点(默认为10个),然后构造一个大小为(future_days, 2)的全零数组(dataSetFuture),并设置未来天数的日期。将历史数据和未来数据合并成一个完整的数据集(dataSetFull)。接下来,将数据集划分为大小为self.n_past的时间步长,然后对于每个时间步长,将其作为输入数据,使用LSTModel模型进行预测,并将预测结果添加到all_data列表中。最后,将预测结果更新到dataSetFull中,并返回预测结果。
阅读全文