生成一个lstm预测降水的代码
时间: 2024-01-26 19:03:17 浏览: 119
以下是一个基于LSTM的降水预测代码,使用的是Python和Keras深度学习库:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense, LSTM
# 读取数据
data = pd.read_csv('rainfall_data.csv')
# 将日期转换为时间序列
data['Date'] = pd.to_datetime(data['Date'], format='%Y-%m-%d')
data.index = data['Date']
# 创建新的数据帧,只包含日期和降水数据
new_data = pd.DataFrame(index=range(0,len(data)),columns=['Date', 'Rainfall'])
for i in range(0,len(data)):
new_data['Date'][i] = data['Date'][i]
new_data['Rainfall'][i] = data['Rainfall'][i]
# 将日期设置为索引
new_data.index = new_data.Date
new_data.drop('Date', axis=1, inplace=True)
# 分割数据集为训练和验证集
dataset = new_data.values
train = dataset[0:200,:]
valid = dataset[200:,:]
# 将数据集缩放到0到1的范围内
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(dataset)
# 创建训练数据集
x_train, y_train = [], []
for i in range(60,len(train)):
x_train.append(scaled_data[i-60:i,0])
y_train.append(scaled_data[i,0])
x_train, y_train = np.array(x_train), np.array(y_train)
# 将数据重塑成LSTM可接受的三维格式
x_train = np.reshape(x_train, (x_train.shape[0],x_train.shape[1],1))
# 创建和拟合LSTM网络
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(x_train.shape[1],1)))
model.add(LSTM(units=50))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(x_train, y_train, epochs=1, batch_size=1, verbose=2)
# 创建测试数据集
inputs = new_data[len(new_data) - len(valid) - 60:].values
inputs = inputs.reshape(-1,1)
inputs = scaler.transform(inputs)
X_test = []
for i in range(60,inputs.shape[0]):
X_test.append(inputs[i-60:i,0])
X_test = np.array(X_test)
# 将数据重塑成LSTM可接受的三维格式
X_test = np.reshape(X_test, (X_test.shape[0],X_test.shape[1],1))
# 预测测试集数据
closing_price = model.predict(X_test)
closing_price = scaler.inverse_transform(closing_price)
# 绘制预测结果
train = new_data[:200]
valid = new_data[200:]
valid['Predictions'] = closing_price
plt.plot(train['Rainfall'])
plt.plot(valid[['Rainfall','Predictions']])
plt.show()
```
这个代码假设你的数据文件是名为 `rainfall_data.csv` 的CSV文件,其中包含一个名为 `Date` 的日期列和一个名为 `Rainfall` 的降水数据列。这个代码使用LSTM来预测降水数据,并将其可视化为训练数据集和预测结果。
阅读全文