spark实现决策树
时间: 2023-07-05 16:18:45 浏览: 72
Apache Spark提供了机器学习库MLlib,其中包括了决策树的实现。下面是一个简单的示例代码,展示如何使用Spark MLlib实现决策树:
```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}
// 加载数据
val data = spark.read.format("csv")
.option("header", "true")
.option("inferSchema", "true")
.load("data.csv")
// 准备数据
val labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(data)
val assembler = new VectorAssembler()
.setInputCols(Array("feature1", "feature2", "feature3"))
.setOutputCol("features")
// 切分数据集
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
// 训练决策树模型
val dt = new DecisionTreeClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("features")
val labelConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels)
val pipeline = new Pipeline()
.setStages(Array(labelIndexer, assembler, dt, labelConverter))
val model = pipeline.fit(trainingData)
// 在测试数据上评估模型
val predictions = model.transform(testData)
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")
val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println(s"Learned classification tree model:\n ${treeModel.toDebugString}")
```
其中,我们使用了一个CSV文件作为数据源,并且使用了三个特征(feature1、feature2、feature3)和一个标签(label)。首先,我们使用StringIndexer将标签转换为数值,然后使用VectorAssembler将特征组合成一个向量。之后,我们将数据集切分为训练数据集和测试数据集,并使用DecisionTreeClassifier训练决策树模型。最后,我们使用MulticlassClassificationEvaluator评估模型的精度,并输出决策树模型的调试字符串。
上述代码只是一个简单的示例,实际使用时需要根据具体情况进行调整和优化。
阅读全文