深度学习数据集制作pytorch代码
时间: 2023-11-16 08:03:51 浏览: 100
首先,你需要准备好你的数据集,格式可以是图片或文本。然后你可以使用PyTorch中的`Dataset`和`DataLoader`类来加载和处理数据,以下是一个简单的例子:
对于图片数据集,你可以使用`torchvision.datasets.ImageFolder`类,这个类会自动把数据集文件夹中的图片按照类别进行分类。假设你的数据集文件夹的路径是`./data`,其中包含10个子文件夹,每个子文件夹包含一种类别的图片,你可以这样加载数据集:
```python
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
# 定义数据转换
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载数据集
dataset = ImageFolder(root='./data', transform=transform)
# 定义数据加载器
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
```
对于文本数据集,你需要自己定义一个`Dataset`类来加载和处理数据,以下是一个简单的例子:
```python
import torch
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
def __init__(self, text_file, label_file, vocab):
self.text = []
self.labels = []
self.vocab = vocab
# 读取文本和标签文件
with open(text_file, 'r') as f:
for line in f:
tokens = line.strip().split()
# 把单词转换成索引
self.text.append([self.vocab[token] for token in tokens])
with open(label_file, 'r') as f:
for line in f:
label = int(line.strip())
self.labels.append(label)
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
# 把文本和标签转换成张量
text = torch.tensor(self.text[index])
label = torch.tensor(self.labels[index])
return text, label
# 加载词汇表
vocab = {}
with open('vocab.txt', 'r') as f:
for i, line in enumerate(f):
token = line.strip()
vocab[token] = i
# 加载数据集
dataset = TextDataset('text.txt', 'labels.txt', vocab)
# 定义数据加载器
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
```
以上是两个简单的例子,你可以根据自己的需求来修改和扩展。
阅读全文