为以下代码添加注释batch_size = 64 # Create data loaders. train_dataloader = DataLoader(training_data, batch_size=batch_size) test_dataloader = DataLoader(test_data, batch_size=batch_size) for X, y in test_dataloader: print(f"Shape of X [N, C, H, W]: {X.shape}") print(f"Shape of y: {y.shape} {y.dtype}") break
时间: 2023-07-14 10:12:34 浏览: 101
pytorch dataloader 取batch_size时候出现bug的解决方式
# 设置每个 batch 的大小为 64
batch_size = 64
# 创建训练集和测试集的数据加载器
# training_data 和 test_data 是已经处理好的数据集
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
# 循环遍历测试集的数据加载器
for X, y in test_dataloader:
# X 是输入数据,其形状为 [N, C, H, W],其中 N 为 batch_size,C 为通道数,H 和 W 分别为高度和宽度
print(f"Shape of X [N, C, H, W]: {X.shape}")
# y 是标签数据,其形状和数据类型
print(f"Shape of y: {y.shape} {y.dtype}")
# 只打印一次,即打印第一个 batch 的数据
break
阅读全文