def PrepareDataset(speed_matrix, BATCH_SIZE=40, seq_len=10, pred_len=1, train_propotion=0.7, valid_propotion=0.2):
    """Build training and validation DataLoaders of sliding-window samples.

    Every day in ``speed_matrix`` is cut into overlapping windows: the first
    ``seq_len`` timesteps of a window are the model input and the following
    ``pred_len`` timesteps are the target. Windows never cross a day
    boundary. Samples are ordered day by day, then by start timestep, and are
    split in that order into train / validation / (held-out) test portions.

    Args:
        speed_matrix: numpy array of shape (num_days, num_timesteps, num_nodes).
        BATCH_SIZE: batch size used by both loaders (default = 40).
        seq_len: number of input timesteps per sample (default = 10).
        pred_len: number of timesteps to predict per sample (default = 1).
        train_propotion: fraction of the samples used for training
            (default = 0.7). The parameter name keeps its historical spelling
            so existing keyword callers continue to work.
        valid_propotion: fraction of the samples used for validation
            (default = 0.2).

    Returns:
        train_data: DataLoader over the training samples (shuffled).
        valid_data: DataLoader over the validation samples (shuffled).
        Each batch is a pair ``(input, target)`` with shapes
        ``(batch, seq_len, num_nodes)`` and ``(batch, pred_len, num_nodes)``.
        The samples left over after the train and validation portions are
        the test portion; it is not returned by this function.

    Raises:
        ValueError: if no complete window of ``seq_len + pred_len`` timesteps
            fits into the data.
    """
    num_days = speed_matrix.shape[0]
    num_timesteps = speed_matrix.shape[1]

    # A window starting at `start` needs timesteps start .. start+window_len-1,
    # so the last valid start is num_timesteps - window_len (hence the +1).
    window_len = seq_len + pred_len
    windows_per_day = num_timesteps - window_len + 1
    if num_days == 0 or windows_per_day <= 0:
        raise ValueError(
            "speed_matrix with %d day(s) of %d timestep(s) is too short for "
            "seq_len=%d and pred_len=%d" % (num_days, num_timesteps, seq_len, pred_len)
        )

    # Create input and target sequences
    input_seq = []
    target_seq = []
    for day in range(num_days):
        for start in range(windows_per_day):
            boundary = start + seq_len
            input_seq.append(speed_matrix[day, start:boundary, :])
            target_seq.append(speed_matrix[day, boundary:boundary + pred_len, :])
    input_seq = np.array(input_seq)
    target_seq = np.array(target_seq)

    # Split sizes must be based on the number of generated samples, not on
    # the raw number of timesteps; otherwise the slices below overshoot and
    # the validation portion ends up truncated or empty.
    num_samples = input_seq.shape[0]
    train_size = int(num_samples * train_propotion)
    valid_size = int(num_samples * valid_propotion)
    valid_end = train_size + valid_size

    train_input = input_seq[:train_size]
    train_target = target_seq[:train_size]
    valid_input = input_seq[train_size:valid_end]
    valid_target = target_seq[train_size:valid_end]

    # Convert training and validation data to PyTorch DataLoader objects
    train_data = DataLoader(
        TensorDataset(torch.Tensor(train_input), torch.Tensor(train_target)),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    valid_data = DataLoader(
        TensorDataset(torch.Tensor(valid_input), torch.Tensor(valid_target)),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    return train_data, valid_data
# (Read the full article)