pyspark读取本地文件实现线性回归代码,及数据集
时间: 2023-12-09 21:03:38 浏览: 110
以下是一个简单的 PySpark 线性回归代码实例,其中使用了本地文件作为数据集。
数据集:
```
1.0,2.0
2.0,3.0
3.0,4.0
4.0,5.0
5.0,6.0
```
代码实现:
```python
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler
# 创建 SparkSession
spark = SparkSession.builder.appName("LinearRegressionExample").getOrCreate()
# 加载数据集
data = spark.read.format("csv").option("header", "false").option("inferSchema", "true").load("data.csv")
data.show()
# 数据预处理
assembler = VectorAssembler(inputCols=["_c0"], outputCol="features")
data = assembler.transform(data).select("features", "_c1").withColumnRenamed("_c1", "label")
data.show()
# 划分数据集
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# 建立线性回归模型
lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
# 训练模型
model = lr.fit(trainingData)
# 预测结果
predictions = model.transform(testData)
predictions.show()
# 计算模型评估指标
from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(predictionCol="prediction", labelCol="label", metricName="r2")
r2 = evaluator.evaluate(predictions)
print("R Squared (R2) on test data = %g" % r2)
# 关闭 SparkSession
spark.stop()
```
注:此代码仅供参考,实际使用时需根据数据集和需求进行适当修改。
阅读全文