pytorch生成三类数据集代码
时间: 2024-10-19 15:06:23 浏览: 36
在PyTorch中,生成三类数据集通常会涉及到创建自定义的数据生成器或者使用内置的数据集生成工具,例如`torch.utils.data.Dataset`和`torch DataLoader`。这里我们假设你想要创建一个简单的分类任务,其中包含三种类型的随机数据点。
首先,我们需要定义一个基础的数据集类,比如`MyDataset`:
```python
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, num_samples, data_size=64, n_classes=3):
self.data = torch.randn(num_samples, data_size)
self.labels = torch.randint(0, n_classes, (num_samples,))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
x = self.data[idx]
y = self.labels[idx]
return x, y
```
这个类创建了大小为`data_size`的随机浮点数作为特征数据,标签是0到`n_classes - 1`之间的整数。`__len__`和`__getitem__`方法分别用于获取数据集长度和单个样本。
然后你可以创建一个数据加载器来迭代这个数据集:
```python
batch_size = 32
dataset = MyDataset(num_samples=1000) # 更改为实际需要的样本数
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
for inputs, labels in dataloader:
print(inputs.shape, labels.shape)
```
这将返回一个批次的输入数据和对应的标签,可以用于训练模型。
阅读全文