data= data[["本车加速度",'车头间距',"原车道前车速度",'本车速度']] data = data.dropna() print(data.tail()) cs = 1 if cs == 1: data = data batch_size = 2 # 批训练大小 epoch = 50 # 迭代次数 windows = 2 # 时间窗 else: pass #把数据处理成lstm接受的输入形式 y = data['本车加速度'] #归一化 data = np.array(data)/scale cut = round(test_ratio* data.shape[0]) amount_of_features=data.shape[1] lstm_input=[] data_temp=data for i in range(len(data_temp)-windows): lstm_input.append(data_temp[i:i+windows,:,:]) lstm_input=np.array(lstm_input) lstm_output=y[:-windows] lstm_output=np.array(lstm_output) x_train,y_train,x_test,y_test=\ lstm_input[:-cut,:,:,:],lstm_output[:-cut:],lstm_input[-cut:,:,:,:],lstm_output[-cut:]如果我要添加一列作为输入,代码如何修改
时间: 2024-01-13 19:04:22 浏览: 79
用于mysql8.0版本的MySql.Data.dll类库,直接引用使用
5星 · 资源好评率100%
如果你要添加一列作为输入,可以按照以下步骤修改代码:
1. 在之前的代码中,我们使用 `data` 变量来存储输入数据。首先,你需要将新的输入数据加入到 `data` 变量中,并将其加入到 `lstm_input` 列表中。具体地,假设你要加入的列名为 `new_feature`,你可以执行以下代码:
```
data["new_feature"] = new_feature_data
```
其中,`new_feature_data` 是一个 Numpy 数组,表示新的输入数据。这行代码将 `new_feature_data` 加入到 `data` 变量中,并将其命名为 `new_feature`。
2. 在之前的代码中,我们使用 `data_temp` 变量来存储输入数据。现在,由于 `data` 变量中包含了新的输入数据,我们需要将 `data` 赋值给 `data_temp`。具体地,你可以执行以下代码:
```
data_temp = data
```
3. 在之前的代码中,我们使用 `lstm_input` 列表来存储 LSTM 网络的输入数据。现在,由于 `data` 变量中包含了新的输入数据,我们需要将 `data_temp` 按照新的格式转换成 LSTM 网络的输入数据,并存储到 `lstm_input` 列表中。具体地,你可以执行以下代码:
```
lstm_input = []
for i in range(len(data_temp)-windows):
lstm_input.append(data_temp[i:i+windows,:,:])
lstm_input = np.array(lstm_input)
```
注意,在这个代码块中,我们使用 `data_temp` 而不是 `data` 变量,因为 `data` 变量中包含了新的输入数据,而这些数据还没有进行窗口切分。
4. 在之前的代码中,我们使用 `amount_of_features` 变量来存储输入数据的特征数量。现在,由于新的输入数据增加了一个特征,我们需要将 `amount_of_features` 的值加一。具体地,你可以执行以下代码:
```
amount_of_features = data.shape[1]
```
将这些修改后的代码添加到原来的代码中即可。
阅读全文