创建一个类似cifar10的数据集代码实现
时间: 2023-09-26 21:07:59 浏览: 93
以下是一个简单的代码实现,用于创建一个类似于CIFAR-10数据集的图像数据集。
```python
import numpy as np
import os
import cv2
class CIFAR10Dataset:
def __init__(self, data_dir):
self.data_dir = data_dir
self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
self.num_classes = len(self.classes)
self.train_images, self.train_labels = self.load_data('train')
self.test_images, self.test_labels = self.load_data('test')
def load_data(self, split):
images = []
labels = []
split_dir = os.path.join(self.data_dir, split)
for i, cls in enumerate(self.classes):
cls_dir = os.path.join(split_dir, cls)
for img_file in os.listdir(cls_dir):
img_path = os.path.join(cls_dir, img_file)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
images.append(img)
labels.append(i)
return np.array(images), np.array(labels)
```
上述代码会将数据集按照10个类别加载到内存中。其中,`load_data` 方法会遍历文件夹中的图像,并将其转换为RGB格式的Numpy数组。最终返回的 `train_images`,`train_labels`,`test_images`,`test_labels` 分别是训练集和测试集的图像和标签,可以用于训练和评估机器学习模型。
阅读全文