怎么使用 torch.utils.data.TensorDataset 类和 torch.utils.data.DataLoader 类来读取两个特征矩阵的数据
时间: 2023-07-23 20:23:45 浏览: 100
使用 `torch.utils.data.TensorDataset` 类可以将多个张量组合成数据集,然后使用 `torch.utils.data.DataLoader` 类来读取数据集。
假设我们有两个特征矩阵 `x1` 和 `x2`,以及对应的标签 `y`,可以将它们转换成 `TensorDataset` 对象,如下所示:
```python
import torch
from torch.utils.data import TensorDataset, DataLoader
# 定义特征矩阵和标签
x1 = torch.rand(100, 10)
x2 = torch.rand(100, 5)
y = torch.randint(0, 2, (100, ))
# 转换为 TensorDataset 对象
dataset = TensorDataset(x1, x2, y)
```
接下来,我们可以使用 `DataLoader` 类来读取数据集。`DataLoader` 类可以指定 `batch_size`(批量大小)、`shuffle`(是否打乱数据)、`num_workers`(使用多少个子进程来读取数据)等参数。
```python
# 定义 DataLoader 对象
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
# 读取数据
for batch in dataloader:
x1_batch, x2_batch, y_batch = batch
print(x1_batch.shape, x2_batch.shape, y_batch.shape)
```
这样,我们就可以使用 `DataLoader` 类来方便地读取两个特征矩阵的数据了。
阅读全文