def PrepareDataset(speed_matrix, BATCH_SIZE=40, seq_len=10, pred_len=1,
                   train_propotion=0.7, valid_propotion=0.2, *,
                   num_workers=4, pin_memory=True):
    """
    Prepare train/validation/test data loaders from a speed matrix.

    The samples are split into three contiguous, non-overlapping slices along
    the first axis, in their original order (the split itself is not
    shuffled). All three splits are then normalized with a single scalar mean
    and standard deviation computed over every element of the training split
    only, so no statistics leak from the validation/test data.

    Inputs:
        speed_matrix: array indexed by sample along axis 0, e.g. of shape
            (num_samples, num_nodes, input_dim); it must expose ``.shape``
            and support slicing and arithmetic along the first axis.
        BATCH_SIZE: the batch size used by all three loaders
        seq_len: the length of the input sequence (forwarded to TrafficDataset)
        pred_len: the length of the prediction sequence (forwarded to
            TrafficDataset)
        train_propotion: the proportion of samples used for training
        valid_propotion: the proportion of samples used for validation; the
            remaining samples are used for testing
        num_workers: number of worker processes per DataLoader (keyword-only)
        pin_memory: whether the DataLoaders pin memory (keyword-only)
    Outputs:
        train_loader: the data loader for training (shuffled)
        valid_loader: the data loader for validation (in order)
        test_loader: the data loader for testing (in order)
    Raises:
        ValueError: if the proportions produce an empty training split, a
            negative validation split, or splits larger than the data.

    NOTE: the training mean/std are not returned, so callers cannot
    de-normalize predictions using this function's output alone.
    """
    # Only the sample axis matters for splitting; trailing dimensions are
    # passed through untouched, so inputs of any rank >= 1 are accepted.
    num_samples = speed_matrix.shape[0]
    num_train = int(num_samples * train_propotion)
    num_valid = int(num_samples * valid_propotion)

    # Validate the split sizes up front: an empty training split would make
    # the normalization statistics NaN, and oversized splits would silently
    # truncate the validation set and leave the test set empty.
    if num_train <= 0:
        raise ValueError(
            f"training split is empty (num_samples={num_samples}, "
            f"train_propotion={train_propotion})"
        )
    if num_valid < 0:
        raise ValueError(f"valid_propotion must be >= 0, got {valid_propotion}")
    if num_train + num_valid > num_samples:
        raise ValueError(
            f"train_propotion + valid_propotion exceeds the data "
            f"({num_train} + {num_valid} > {num_samples} samples)"
        )

    # Contiguous slices along axis 0: train, then validation, then test.
    valid_end = num_train + num_valid
    train_data = speed_matrix[:num_train]
    valid_data = speed_matrix[num_train:valid_end]
    test_data = speed_matrix[valid_end:]

    # One global scalar mean/std over all training elements (not per node).
    mean = np.mean(train_data)
    std = np.std(train_data)
    if std == 0:
        # Constant training data: avoid division by zero; values become 0.
        std = 1.0
    train_data = (train_data - mean) / std
    valid_data = (valid_data - mean) / std
    test_data = (test_data - mean) / std

    def _make_loader(data, shuffle):
        """Wrap one normalized split in a TrafficDataset-backed DataLoader."""
        return DataLoader(
            dataset=TrafficDataset(data, seq_len, pred_len),
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    # Only training batches are shuffled; evaluation splits keep their order
    # so validation and test passes are reproducible.
    train_loader = _make_loader(train_data, shuffle=True)
    valid_loader = _make_loader(valid_data, shuffle=False)
    test_loader = _make_loader(test_data, shuffle=False)
    return train_loader, valid_loader, test_loader
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![mp4](https://img-home.csdnimg.cn/images/20210720083504.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)