划分数据集python代码
时间: 2023-07-09 18:09:28 浏览: 198
以下是一个简单的Python代码示例,用于将数据集划分为训练集、验证集和测试集:
```python
import numpy as np
import os
# 设置数据集路径
data_dir = '/path/to/dataset'
# 定义训练,验证和测试集比例
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1
# 获取所有文件名和标签
data_files = os.listdir(data_dir)
data_labels = [filename.split('_')[0] for filename in data_files]
# 将标签转换为数字编码
label_set = set(data_labels)
label_dict = {label: i for i, label in enumerate(label_set)}
data_labels = [label_dict[label] for label in data_labels]
# 将数据和标签合并为元组
data_tuples = list(zip(data_files, data_labels))
# 打乱数据集
np.random.shuffle(data_tuples)
# 计算划分点
train_point = int(len(data_tuples) * train_ratio)
val_point = int(len(data_tuples) * (train_ratio + val_ratio))
# 划分数据集
train_data = data_tuples[:train_point]
val_data = data_tuples[train_point:val_point]
test_data = data_tuples[val_point:]
# 保存划分后的数据集
np.save(os.path.join(data_dir, 'train_data.npy'), train_data)
np.save(os.path.join(data_dir, 'val_data.npy'), val_data)
np.save(os.path.join(data_dir, 'test_data.npy'), test_data)
```
在此示例中,我们使用NumPy和os模块来处理数据集文件和标签,然后使用Python的随机模块打乱数据集。我们还定义了三个比例来分配数据集中的每个部分,并计算用于划分数据集的索引点。最后,我们将划分后的数据集保存为NumPy数组。您可以根据自己的需求和数据集结构修改此代码示例。
阅读全文