from torch.utils.data import TensorDataset,DataLoader含义
时间: 2023-07-23 18:25:15 浏览: 96
DataLoader.py_torch数据_.DataLoader数据加载器_源码.zip
`TensorDataset` 和 `DataLoader` 是 PyTorch 中用于数据处理和批量加载的工具。
`TensorDataset` 可以将数据集作为参数传入,将每个样本和标签封装成一个元组,然后将所有元组存储在一个数据集中。
`DataLoader` 则可以将一个数据集分成多个小批量进行加载,方便训练模型。可以设置批量大小、是否随机打乱数据和是否使用多线程等参数。
下面是一个简单的例子:
```
import torch
from torch.utils.data import TensorDataset, DataLoader
# 创建数据集
x = torch.randn(100, 3)
y = torch.randn(100, 1)
dataset = TensorDataset(x, y)
# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
# 遍历数据集
for batch_x, batch_y in dataloader:
print(batch_x.shape, batch_y.shape)
```
在这个例子中,我们先创建了一个包含 100 个样本和标签的数据集 `dataset`,然后使用 `DataLoader` 将其分成批量大小为 10 的小批量,并打乱数据。在遍历数据集时,每次输出一个小批量的样本和标签,其形状分别为 `(10, 3)` 和 `(10, 1)`。
阅读全文