手动实现 train_test_split 函数
时间: 2023-10-13 17:12:22 浏览: 101
可以实现 train_test_split 函数,代码如下:
```python
import random
def train_test_split(X, y, test_size=0.2, random_seed=None):
"""
将数据集 X, y 按照 test_size 分割成训练集和测试集
"""
if random_seed:
# 设置随机种子
random.seed(random_seed)
# 计算测试集的数量
test_num = int(len(X) * test_size)
# 将索引打乱
index = list(range(len(X)))
random.shuffle(index)
# 获取训练集和测试集的索引
test_index = index[:test_num]
train_index = index[test_num:]
# 根据索引获取训练集和测试集
X_train = X[train_index]
y_train = y[train_index]
X_test = X[test_index]
y_test = y[test_index]
return X_train, X_test, y_train, y_test
```
其中,X 和 y 分别是数据集的特征和标签,test_size 表示测试集占总样本的比例(默认为 0.2),random_seed 表示随机种子(默认为 None)。该函数首先计算测试集的数量,然后将数据集的索引打乱,再根据打乱后的索引获取训练集和测试集的索引,最后根据索引获取训练集和测试集。
阅读全文