python实现:将手写体图像数据集进行切割,然后按照数字进行分类,构建手写体数字的分类数据集,阿拉伯数字和中文数字分别建立独立的数据集,以人为单位,当70%同学的手写体数字作为训练集,剩余30%同学的作为测试集,构建Pytorch的数据加载器,进行Batch方式的数据生成。的代码
时间: 2024-03-07 08:47:22 浏览: 159
好的,下面是一个可能的代码实现,供您参考:
```python
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
# 定义手写数字数据集类
class DigitDataset(Dataset):
def __init__(self, data_dir, img_size):
self.data_dir = data_dir
self.img_size = img_size
self.digits = os.listdir(data_dir)
self.num_digits = len(self.digits)
self.label_map = {digit: i for i, digit in enumerate(self.digits)}
self.images, self.labels = self.load_images()
def __len__(self):
return len(self.images)
def __getitem__(self, index):
img = Image.open(self.images[index]).convert('L')
img = img.resize((self.img_size, self.img_size))
img = np.array(img).astype(np.float32) / 255.0
label = self.labels[index]
return img, label
def load_images(self):
images = []
labels = []
for digit in self.digits:
label = self.label_map[digit]
digit_dir = os.path.join(self.data_dir, digit)
for file_name in os.listdir(digit_dir):
file_path = os.path.join(digit_dir, file_name)
images.append(file_path)
labels.append(label)
return images, labels
# 切割手写数字图像
def crop_image(image_path, save_dir):
img = Image.open(image_path).convert('L')
img_array = np.array(img)
img_array[img_array < 128] = 0
img_array[img_array >= 128] = 255
img = Image.fromarray(img_array)
digits = []
for i in range(10):
left, upper, right, lower = get_digit_box(img_array, i)
if left >= 0 and upper >= 0 and right >= 0 and lower >= 0:
digit_img = img.crop((left, upper, right, lower))
digit_path = os.path.join(save_dir, str(i))
if not os.path.exists(digit_path):
os.makedirs(digit_path)
digit_img.save(os.path.join(digit_path, os.path.basename(image_path)))
digits.append(i)
return digits
# 获取图像中指定数字的边界框
def get_digit_box(img, digit):
height, width = img.shape
left, upper, right, lower = -1, -1, -1, -1
for i in range(height):
for j in range(width):
if img[i, j] == digit:
if left == -1 or j < left:
left = j
if upper == -1 or i < upper:
upper = i
if right == -1 or j > right:
right = j
if lower == -1 or i > lower:
lower = i
return left, upper, right, lower
# 划分数据集
def split_dataset(data_dir, train_dir, test_dir, test_ratio):
for digit in os.listdir(data_dir):
digit_dir = os.path.join(data_dir, digit)
X = np.array([os.path.join(digit_dir, file_name) for file_name in os.listdir(digit_dir)])
y = np.array([int(digit)] * len(X))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio, random_state=42)
for file_path in X_train:
os.makedirs(os.path.join(train_dir, digit), exist_ok=True)
os.rename(file_path, os.path.join(train_dir, digit, os.path.basename(file_path)))
for file_path in X_test:
os.makedirs(os.path.join(test_dir, digit), exist_ok=True)
os.rename(file_path, os.path.join(test_dir, digit, os.path.basename(file_path)))
# 构建数据加载器
def build_dataloader(data_dir, img_size, batch_size, test_ratio):
# 划分数据集
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')
if not os.path.exists(train_dir) or not os.path.exists(test_dir):
split_dataset(data_dir, train_dir, test_dir, test_ratio)
# 加载数据集
train_dataset = DigitDataset(train_dir, img_size)
test_dataset = DigitDataset(test_dir, img_size)
# 构建数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return train_loader, test_loader
```
使用方法:
```python
# 切割手写数字图像,按照数字进行分类
data_dir = 'path/to/dataset'
save_dir = 'path/to/save'
for file_name in os.listdir(data_dir):
file_path = os.path.join(data_dir, file_name)
crop_image(file_path, save_dir)
# 构建数据加载器
img_size = 28
batch_size = 64
test_ratio = 0.3
train_loader, test_loader = build_dataloader(save_dir, img_size, batch_size, test_ratio)
```
其中,data_dir为原始手写数字图像数据集所在目录,save_dir为处理后的手写数字图像数据集所在目录。img_size为图像的大小,batch_size为每个Batch中的图像数量,test_ratio为测试集所占比例。在调用build_dataloader函数时,会自动划分数据集并构建数据加载器。
阅读全文