使用pyspark mllib,拆分lpsa.data以data为后缀的数据集,写线性回归代码
时间: 2023-11-22 16:55:42 浏览: 82
可以使用以下代码拆分数据集:
```python
from pyspark import SparkContext
from pyspark.sql import SparkSession
sc = SparkContext()
spark = SparkSession(sc)
# 读取数据
data = spark.read.format("libsvm").load("lpsa.data")
# 拆分数据集
train_data = data.filter(data["label"] < 2.5)
test_data = data.filter(data["label"] >= 2.5)
# 将数据集缓存,加快计算速度
train_data.cache()
test_data.cache()
```
然后,可以使用以下代码实现线性回归:
```python
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
# 创建线性回归模型
lr = LinearRegression(featuresCol="features", labelCol="label")
# 训练模型
lr_model = lr.fit(train_data)
# 预测结果
predictions = lr_model.transform(test_data)
# 评估模型
evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
```
其中,`LinearRegression` 是 PySpark 中的线性回归模型,`RegressionEvaluator` 可以用于评估模型的性能。
阅读全文