data.TensorDataset()
时间: 2024-10-14 14:03:17 浏览: 35
`TensorDataset`是PyTorch库中的一个基础数据集类,它用于存储和管理一组张量(Tensor)对,通常这些张量对应于模型的输入和目标变量。以下是它的主要用途和操作示例[^1]:
1. **创建数据集**:
```python
import torch
from torch.utils.data import TensorDataset
# 假设我们有输入特征features和对应的标签labels
features = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
labels = torch.tensor([0, 1, 0])
# 使用TensorDataset创建数据集
dataset = TensorDataset(features, labels)
```
2. **访问样本**:
数据集可以通过索引来访问其中的单个样本,每个样本都是一个包含输入和标签的元组。例如,第一个样本:
```python
sample = dataset[0]
print("First sample:", sample)
```
3. **遍历数据集**:
当用于训练时,可以使用`DataLoader`来按批次加载数据并迭代整个数据集:
```python
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=2) # 假设batch_size为2
for inputs, targets in dataloader:
# inputs: (batch_size, feature_shape)
# targets: (batch_size,)
```
阅读全文