python实现1.将手写体图像数据集进行切割,然后按照数字进行分类,构建手写体数字的分类数据集。 阿拉伯数字和中文数字分别建立独立的数据集。 2.以人为单位,当70%同学的手写体数字作为训练集,剩余30%同学的作为测试集。 3.构建Pytorch的数据加载器,进行Batch方式的数据生成。
时间: 2024-03-07 16:46:36 浏览: 64
1. 实现手写体图像数据集切割、分类
下面是一个简单的 Python 实现,可以将手写数字图像数据集按照数字进行分类,并保存到相应的文件夹中:
```python
import os
import shutil
from PIL import Image
data_dir = '/path/to/handwritten_digits_dataset'
output_dir = '/path/to/classified_dataset'
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for filename in os.listdir(data_dir):
if filename.endswith('.png'):
img = Image.open(os.path.join(data_dir, filename))
digit = filename.split('.')[0][-1] # 获取图像文件名中的数字
if digit.isdigit():
output_subdir = os.path.join(output_dir, digit)
if not os.path.exists(output_subdir):
os.makedirs(output_subdir)
output_path = os.path.join(output_subdir, filename)
shutil.copyfile(os.path.join(data_dir, filename), output_path)
```
2. 划分训练集和测试集
下面是一个简单的 Python 实现,可以将手写数字数据集划分为训练集和测试集,比例为 7:3:
```python
import os
import random
import shutil
data_dir = '/path/to/classified_dataset'
train_dir = '/path/to/train_dataset'
test_dir = '/path/to/test_dataset'
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(test_dir):
os.makedirs(test_dir)
for digit_dir in os.listdir(data_dir):
digit_path = os.path.join(data_dir, digit_dir)
digit_files = os.listdir(digit_path)
random.shuffle(digit_files)
num_train = int(len(digit_files) * 0.7)
train_files = digit_files[:num_train]
test_files = digit_files[num_train:]
for filename in train_files:
input_path = os.path.join(digit_path, filename)
output_path = os.path.join(train_dir, filename)
shutil.copyfile(input_path, output_path)
for filename in test_files:
input_path = os.path.join(digit_path, filename)
output_path = os.path.join(test_dir, filename)
shutil.copyfile(input_path, output_path)
```
3. 构建 Pytorch 数据加载器
下面是一个简单的 Pytorch 数据加载器实现,可以对手写数字训练集和测试集进行 Batch 方式的数据生成:
```python
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
class HandwrittenDigitsDataset(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.digit_filenames = []
for digit_dir in os.listdir(data_dir):
digit_path = os.path.join(data_dir, digit_dir)
digit_filenames = os.listdir(digit_path)
self.digit_filenames += [(os.path.join(digit_path, filename), int(digit_dir)) for filename in digit_filenames]
def __len__(self):
return len(self.digit_filenames)
def __getitem__(self, index):
filename, digit = self.digit_filenames[index]
img = Image.open(filename)
img = img.convert('L') # 转换为灰度图像
img = img.resize((28, 28)) # 调整图像大小
img = torch.tensor(img, dtype=torch.float32).unsqueeze(0) / 255.0 # 转换为张量,并进行归一化
return img, digit
train_dataset = HandwrittenDigitsDataset('/path/to/train_dataset')
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = HandwrittenDigitsDataset('/path/to/test_dataset')
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
```
这里的 `HandwrittenDigitsDataset` 类实现了 Pytorch 的 `Dataset` 接口,可以在 Pytorch 的数据加载器中使用。在 `__init__` 方法中,我们遍历了训练集或测试集文件夹中的所有文件,并保存了文件路径和对应的数字标签。在 `__getitem__` 方法中,我们读取了图像文件,并将其转换为灰度图像、调整大小、转换为张量、进行归一化等操作。在数据加载器中,我们可以指定每个 Batch 的大小、是否随机打乱数据等参数。
阅读全文