Data.TensorDataset(features,labels)
时间: 2024-02-18 16:05:41 浏览: 126
`torch.utils.data.TensorDataset` 是 PyTorch 中的一个数据集类,它用于将特征张量和标签张量打包成一个数据集。它接受两个参数:
- `features`:特征张量,即输入数据的特征部分,可以是一个或多个张量。
- `labels`:标签张量,即输入数据的标签部分,可以是一个或多个张量。
使用 `torch.utils.data.TensorDataset` 可以将特征和标签打包成一个数据集,然后使用数据加载器(如 `DataLoader`)对数据进行批处理和加载。
下面是一个使用 `torch.utils.data.TensorDataset` 的例子:
```python
import torch
from torch.utils.data import TensorDataset
# 创建特征张量和标签张量
features = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
labels = torch.tensor([1, 2, 3])
# 创建数据集
dataset = TensorDataset(features, labels)
# 使用数据加载器对数据进行批处理和加载
dataloader = DataLoader(dataset, batch_size=2)
# 遍历数据加载器
for inputs, targets in dataloader:
print(inputs, targets)
```
在这个例子中,我们首先创建了一个特征张量和一个标签张量,然后使用 `TensorDataset` 将它们打包成一个数据集。接下来,我们使用 `DataLoader` 对数据进行批处理和加载,然后遍历数据加载器以获取每个批次的输入和标签。
阅读全文