pytorch 读取无标签数据集,并保存按比例划分的训练集和测试集
时间: 2024-02-03 16:03:38 浏览: 135
以下是使用 PyTorch 读取无标签数据集并保存按比例划分的训练集和测试集的示例代码:
```python
import os
import random
from shutil import copyfile
import torch
from torch.utils.data import Dataset
class UnlabeledDataset(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.image_list = os.listdir(data_dir)
def __len__(self):
return len(self.image_list)
def __getitem__(self, idx):
image_path = os.path.join(self.data_dir, self.image_list[idx])
image = Image.open(image_path).convert('RGB')
return image
def split_dataset(data_dir, train_ratio=0.8):
# Create train and test directories
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)
# Get list of images
image_list = os.listdir(data_dir)
# Shuffle images
random.shuffle(image_list)
# Split dataset
train_size = int(train_ratio * len(image_list))
train_images = image_list[:train_size]
test_images = image_list[train_size:]
# Copy images to train directory
for image_name in train_images:
src_path = os.path.join(data_dir, image_name)
dst_path = os.path.join(train_dir, image_name)
copyfile(src_path, dst_path)
# Copy images to test directory
for image_name in test_images:
src_path = os.path.join(data_dir, image_name)
dst_path = os.path.join(test_dir, image_name)
copyfile(src_path, dst_path)
```
在上面的代码中,`UnlabeledDataset` 类是一个 PyTorch 数据集,它从无标签数据集目录中读取图像。`split_dataset` 函数将无标签数据集按比例划分为训练集和测试集,并将它们保存到相应的目录中。默认使用 0.8 的比例将数据集分为训练集和测试集。
使用示例:
```python
data_dir = '/path/to/unlabeled/dataset'
split_dataset(data_dir)
# Load train and test datasets
train_dataset = UnlabeledDataset(os.path.join(data_dir, 'train'))
test_dataset = UnlabeledDataset(os.path.join(data_dir, 'test'))
# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)
```
上面的示例代码将无标签数据集按比例划分为训练集和测试集,并创建了相应的 PyTorch 数据加载器。
阅读全文