LSTM 来预测车辆轨迹的 Python 代码
时间: 2023-10-14 07:04:58 浏览: 93
Python Pytorch实现基于LSTM的高速公路车辆轨迹预测 源码+数据集+说明
5星 · 资源好评率100%
以下是使用 LSTM 预测车辆轨迹的 Python 代码示例:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM
# 读取数据
dataset = pd.read_csv('car_trajectory.csv')
train_set = dataset.iloc[:, 1:2].values
# 特征缩放
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler(feature_range = (0, 1))
train_set_scaled = scaler.fit_transform(train_set)
# 创建输入和输出序列
X_train = []
y_train = []
for i in range(60, len(train_set_scaled)):
X_train.append(train_set_scaled[i-60:i, 0])
y_train.append(train_set_scaled[i, 0])
X_train, y_train = np.array(X_train), np.array(y_train)
# 转换为 LSTM 可接受的格式
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
# 创建 LSTM 模型
model = Sequential()
model.add(LSTM(units = 50, return_sequences = True, input_shape = (X_train.shape[1], 1)))
model.add(LSTM(units = 50))
model.add(Dense(units = 1))
model.compile(optimizer = 'adam', loss = 'mean_squared_error')
# 训练模型
model.fit(X_train, y_train, epochs = 50, batch_size = 32)
# 预测结果
inputs = train_set_scaled[len(train_set_scaled) - len(X_train):]
inputs = np.reshape(inputs, (inputs.shape[0], inputs.shape[1], 1))
predicted = model.predict(inputs)
predicted = scaler.inverse_transform(predicted)
# 可视化结果
plt.plot(train_set, color = 'red', label = 'Real car trajectory')
plt.plot(predicted, color = 'blue', label = 'Predicted car trajectory')
plt.title('Car Trajectory Prediction')
plt.xlabel('Time')
plt.ylabel('Position')
plt.legend()
plt.show()
```
需要注意的是,这里的 `car_trajectory.csv` 是已经处理好的车辆轨迹数据集,需要自行准备或者从公开数据集中获取。同时,代码中的参数和超参数需要根据具体情况进行调整和优化。
阅读全文