TensorDataset
时间: 2024-05-15 14:12:49 浏览: 101
TensorDataset是PyTorch中的一个类,用于将数据集的张量进行组合。在机器学习中,我们通常会把数据集划分成输入数据和目标数据,这些数据通常作为张量存储。TensorDataset将这些张量组合成一个数据集,使得每个索引位置的输入数据和目标数据可以一一对应。这样,在训练模型时,我们可以很方便地从数据集中取出一组输入数据和对应的目标数据进行训练。
下面是一个使用TensorDataset来组合数据集的示例代码:
```python
import torch
from torch.utils.data import TensorDataset
# 生成输入数据和目标数据的张量
inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
targets = torch.tensor([0, 1, 0, 1])
# 使用TensorDataset将输入数据和目标数据组合成一个数据集
dataset = TensorDataset(inputs, targets)
# 取出第一组输入数据和对应的目标数据
input_0, target_0 = dataset[0]
print('input_0:', input_0)
print('target_0:', target_0)
```
输出结果为:
```
input_0: tensor([1, 2, 3])
target_0: tensor(0)
```
可以看到,我们成功地将输入数据和目标数据组合成了一个数据集,并且可以很方便地从数据集中取出一组输入数据和对应的目标数据进行训练。
阅读全文