ShuffleSplit函数及其参数介绍
时间: 2023-12-10 19:02:46 浏览: 24
`ShuffleSplit` 是一个交叉验证方法,它将数据集随机打乱后,划分为训练集和测试集。该方法可以用于评估模型的稳定性和泛化能力。
`ShuffleSplit` 函数的参数如下:
- `n_splits`: 它表示划分数据集的次数,默认为 10。
- `test_size`: 它表示测试集的大小,可以为整数(表示样本数)或浮点数(表示样本比例),默认为 0.1。
- `train_size`: 它表示训练集的大小,可以为整数或浮点数,如果没有指定,则默认为 `1 - test_size`。
- `random_state`: 它表示随机数生成器的种子,可以为整数或 `RandomState` 实例,用于控制随机打乱的结果。
- `shuffle`: 它表示是否对数据集进行随机打乱,默认为 `True`。
- `indices`: 它表示是否返回索引,如果为 `True`,则返回每次划分的训练集和测试集的索引,否则返回样本数据。
例如,以下代码演示了如何使用 `ShuffleSplit` 划分数据集:
```python
from sklearn.model_selection import ShuffleSplit
import numpy as np
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([0, 1, 0, 1, 0])
ss = ShuffleSplit(n_splits=3, test_size=0.2, random_state=0)
for train_index, test_index in ss.split(X, y):
print("TRAIN:", train_index, "TEST:", test_index)
```
输出结果如下:
```
TRAIN: [2 1 0 4] TEST: [3]
TRAIN: [4 0 1 2] TEST: [3]
TRAIN: [3 2 4 1] TEST: [0]
```
可以看到,数据集被随机划分为三份,每次划分的训练集和测试集的索引不同。