用torch切分文本数据集并将划分好的数据集文件打印出来的代码
时间: 2024-05-05 22:21:32 浏览: 164
以下是用PyTorch切分文本数据集并将划分好的数据集文件打印出来的示例代码:
```python
import torch
from torch.utils.data import Dataset, DataLoader
# 定义自定义数据集类
class TextDataset(Dataset):
def __init__(self, file_path):
with open(file_path, 'r', encoding='utf-8') as f:
self.data = f.readlines()
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 定义数据集路径和划分大小
data_path = 'data.txt'
train_ratio = 0.8
# 加载整个数据集
dataset = TextDataset(data_path)
# 计算划分大小
train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size
# 划分数据集
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# 打印划分好的训练集
print('Train dataset:')
for data in train_dataset:
print(data)
# 打印划分好的测试集
print('Test dataset:')
for data in test_dataset:
print(data)
```
上述代码中,首先定义了一个自定义数据集类`TextDataset`,用于加载文本数据集。然后,指定数据集路径和训练集大小,加载整个数据集并计算划分大小。接着,使用`torch.utils.data.random_split`函数将数据集随机划分为训练集和测试集。最后,分别打印划分好的训练集和测试集。
阅读全文