写一个python将THUCNews数据集读取成dataloader的代码
时间: 2024-02-25 07:53:21 浏览: 139
好的,假设THUCNews数据集已经下载并解压缩到了本地,可以使用如下代码将其读取成dataloader:
```python
from collections import Counter
from functools import partial
import os

import torch
from torch.utils.data import Dataset, DataLoader
class THUCNewsDataset(Dataset):
    """Text-classification dataset read from a THUCNews-style directory tree.

    Expected layout (as walked by ``__init__``)::

        data_path/<category_name>/<any file>

    Every sub-directory of ``data_path`` is one category. Every regular file
    inside it is one sample, read as UTF-8 text. Entries of ``data_path`` that
    are not directories, and entries of a category directory that are not
    regular files, are ignored.

    Args:
        data_path: Root directory of the dataset.
        max_len: Each text is stripped and then truncated to at most this many
            characters. ``None`` keeps the full text.

    Attributes:
        label_to_idx: Mapping category name -> integer label. Categories are
            numbered in sorted name order, so the mapping is reproducible
            across runs and machines.
        data: List of ``(content, label_name)`` tuples.
    """

    def __init__(self, data_path, max_len):
        self.max_len = max_len
        self.label_to_idx = {}
        self.data = []
        # os.listdir returns entries in arbitrary order; sort so that label
        # indices (and sample order) are deterministic.
        for label_name in sorted(os.listdir(data_path)):
            label_path = os.path.join(data_path, label_name)
            if not os.path.isdir(label_path):
                continue
            self.label_to_idx[label_name] = len(self.label_to_idx)
            for file_name in sorted(os.listdir(label_path)):
                file_path = os.path.join(label_path, file_name)
                # Skip nested directories and other non-regular entries, which
                # open() would otherwise fail on.
                if not os.path.isfile(file_path):
                    continue
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read().strip()
                # Slicing with max_len=None keeps the whole string.
                self.data.append((content[:max_len], label_name))

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(content, label_idx)`` for sample ``index``."""
        content, label_name = self.data[index]
        return content, self.label_to_idx[label_name]
# Special tokens reserved at the start of every vocabulary built here.
# Indices 1 and 2 (BOS/EOS) match the ids the original collate_fn hard-coded.
SPECIAL_TOKENS = ("<pad>", "<bos>", "<eos>", "<unk>")
PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX = 0, 1, 2, 3


def build_vocab(texts, tokenizer=None, min_freq=1):
    """Build a token -> id mapping from an iterable of raw texts.

    Args:
        texts: Iterable of strings.
        tokenizer: Callable str -> list of tokens. Defaults to character-level
            splitting (``list``), which suits unsegmented Chinese text.
        min_freq: Tokens seen fewer times than this are left out and will map
            to ``<unk>``.

    Returns:
        dict mapping token -> int. ``SPECIAL_TOKENS`` occupy ids 0-3. Other
        tokens follow in descending frequency, with ties broken by token, so
        the result is deterministic.
    """
    tokenize = list if tokenizer is None else tokenizer
    counts = Counter()
    for text in texts:
        counts.update(tokenize(text))
    vocab = {tok: idx for idx, tok in enumerate(SPECIAL_TOKENS)}
    for tok, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq < min_freq:
            break  # sorted by descending frequency: the rest are rarer
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def collate_fn(batch, vocab=None, max_len=None, tokenizer=None):
    """Turn a list of ``(text, label)`` samples into padded id tensors.

    Each text is tokenized and truncated to ``max_len`` tokens. It is then
    wrapped as ``[BOS] tokens [EOS]`` and right-padded with ``PAD_IDX`` to the
    longest sequence in the batch.

    Args:
        batch: List of ``(text, label_idx)`` pairs, as yielded by
            ``THUCNewsDataset``.
        vocab: dict token -> id. Unknown tokens map to ``UNK_IDX``. Required;
            bind it with ``functools.partial`` when passing this function to a
            ``DataLoader``.
        max_len: Maximum number of content tokens, excluding BOS/EOS.
            ``None`` disables truncation.
        tokenizer: Callable str -> list of tokens. Defaults to characters.

    Returns:
        ``(ids, labels)``. ``ids`` is a ``torch.long`` tensor of shape
        ``(batch, longest_len)``; ``labels`` is a ``torch.long`` tensor of
        shape ``(batch,)``.

    Raises:
        ValueError: If ``vocab`` is not provided.
    """
    if vocab is None:
        raise ValueError(
            "collate_fn needs a vocab; use "
            "functools.partial(collate_fn, vocab=...) or load_data_thucnews()"
        )
    tokenize = list if tokenizer is None else tokenizer
    contents, labels = zip(*batch)

    sequences = []
    for text in contents:
        tokens = tokenize(text)
        if max_len is not None:
            tokens = tokens[:max_len]
        ids = [vocab.get(tok, UNK_IDX) for tok in tokens]
        sequences.append([BOS_IDX] + ids + [EOS_IDX])

    # Sequences have different lengths; torch cannot build a tensor from a
    # ragged list, so pad every row to the longest one.
    longest = max(len(seq) for seq in sequences)
    padded = torch.full((len(sequences), longest), PAD_IDX, dtype=torch.long)
    for row, seq in enumerate(sequences):
        padded[row, : len(seq)] = torch.tensor(seq, dtype=torch.long)
    return padded, torch.tensor(labels, dtype=torch.long)


def load_data_thucnews(batch_size, max_len, data_path="path/to/THUCNews",
                       tokenizer=None, min_freq=1):
    """Load THUCNews as a shuffled training DataLoader plus its vocabulary.

    Args:
        batch_size: Samples per batch.
        max_len: Maximum text length. It is applied when reading files and
            again as the token limit in ``collate_fn``.
        data_path: Dataset root; replace the placeholder default with the real
            location.
        tokenizer: Callable str -> list of tokens. Defaults to characters.
        min_freq: Minimum token frequency to enter the vocabulary.

    Returns:
        ``(train_iter, vocab)``. ``train_iter`` is a ``DataLoader`` yielding
        ``(ids, labels)`` tensors; ``vocab`` is the token -> id dict needed to
        size an embedding layer (``len(vocab)``).
    """
    train_data = THUCNewsDataset(data_path, max_len)
    tokenize = list if tokenizer is None else tokenizer
    vocab = build_vocab(
        (train_data[i][0] for i in range(len(train_data))),
        tokenizer=tokenize,
        min_freq=min_freq,
    )
    # partial over a module-level function stays picklable, so this also
    # works with DataLoader(num_workers > 0), unlike a lambda.
    batch_fn = partial(collate_fn, vocab=vocab, max_len=max_len, tokenizer=tokenize)
    train_iter = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                            collate_fn=batch_fn)
    return train_iter, vocab
```
其中，THUCNewsDataset类负责遍历数据集根目录（每个子目录对应一个类别），逐个读取其中的文本文件；collate_fn函数负责把一个batch中的样本整理成模型可用的张量；load_data_thucnews函数负责创建数据集和DataLoader，并返回train_iter和vocab。使用前请将代码中的'path/to/THUCNews'替换为数据集实际所在的路径；max_len表示每条文本被截断后的最大长度。
阅读全文