tensordataset用法事例
时间: 2023-09-12 11:03:09 浏览: 120
关于dataset的实例
TensorDataset是一个PyTorch内置的数据集类,用于将数据集转换为张量形式,方便进行批处理和数据加载。以下是一个使用TensorDataset的示例:
```python
import torch
from torch.utils.data import TensorDataset, DataLoader
# 创建数据集
X = torch.tensor([[0,0], [0,1], [1,0], [1,1]], dtype=torch.float32) # 输入数据
y = torch.tensor([0, 1, 1, 0], dtype=torch.float32) # 目标数据
dataset = TensorDataset(X, y)
# 创建数据加载器
batch_size = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 打印数据
for i, (X_batch, y_batch) in enumerate(dataloader):
print(f"Batch {i+1}:")
print(f"X: {X_batch}")
print(f"y: {y_batch}")
```
上述代码中,首先创建了一个数据集`dataset`,其中输入数据为`X`,目标数据为`y`。接着使用`DataLoader`创建一个数据加载器,指定了批次大小`batch_size`和是否打乱数据`shuffle`。最后,使用数据加载器遍历数据集,每次加载一个批次数据`X_batch`和`y_batch`,并打印出来。
输出如下:
```
Batch 1:
X: tensor([[1., 0.],
[0., 0.]])
y: tensor([1., 0.])
Batch 2:
X: tensor([[1., 1.],
[0., 1.]])
y: tensor([0., 1.])
```
阅读全文