pyspark mllib 拆分lspa.data数据集
时间: 2023-11-22 11:55:37 浏览: 103
如果你已经有了一个 LabeledPoint 格式的数据集,可以使用 PySpark MLlib 中的 randomSplit 方法将其拆分为训练集和测试集。以下是一个示例代码:
```python
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
# 初始化 SparkContext
sc = SparkContext("local", "Splitting LabeledPoint dataset")
# 读取 LabeledPoint 数据集
data = sc.textFile("path/to/lspa.data")
parsedData = data.map(lambda line: LabeledPoint.parse(line))
# 将数据集拆分为训练集和测试集
trainData, testData = parsedData.randomSplit([0.7, 0.3], seed=1234)
# 打印训练集和测试集的大小
print("Training data size: ", trainData.count())
print("Test data size: ", testData.count())
# 关闭 SparkContext
sc.stop()
```
在这个示例代码中,我们首先使用 SparkContext 读取 LabeledPoint 数据集,然后使用 map 方法将每一行数据转换为 LabeledPoint 对象。接下来,我们使用 randomSplit 方法将数据集拆分为 70% 的训练集和 30% 的测试集,并指定一个种子以确保拆分结果的可重复性。最后,我们打印训练集和测试集的大小,并关闭 SparkContext。
需要注意的是,如果你的 LabeledPoint 数据集已经按照类别划分好了,为了避免训练集和测试集中的数据类别分布不均匀,你可以在 randomSplit 方法中使用 stratified 参数进行分层抽样。以下是一个示例代码:
```python
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
# 初始化 SparkContext
sc = SparkContext("local", "Splitting LabeledPoint dataset with stratification")
# 读取 LabeledPoint 数据集
data = sc.textFile("path/to/lspa.data")
parsedData = data.map(lambda line: LabeledPoint.parse(line))
# 将数据集按照类别划分为训练集和测试集
trainData, testData = parsedData.randomSplit([0.7, 0.3], seed=1234, stratified=True)
# 打印训练集和测试集的大小
print("Training data size: ", trainData.count())
print("Test data size: ", testData.count())
# 关闭 SparkContext
sc.stop()
```
在这个示例代码中,我们在 randomSplit 方法中使用 stratified 参数进行分层抽样,确保训练集和测试集中的数据类别分布均匀。
阅读全文