随机森林模型 scala代码 roc曲线 混淆矩阵
时间: 2023-08-06 12:17:36 浏览: 156
随机森林模型代码
好的,以下是使用Scala实现随机森林模型的代码示例:
```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
object RandomForestModel {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("RandomForestModel")
.master("local[*]")
.getOrCreate()
// 读取数据
val data = spark.read.option("header", "true").option("inferSchema", "true").csv("path/to/data.csv")
// 特征工程
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
val assembler = new VectorAssembler().setInputCols(Array("feature1", "feature2", "feature3")).setOutputCol("features")
// 划分数据集
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
// 定义随机森林模型
val rf = new RandomForestClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("features")
.setNumTrees(10)
// 将预测结果转换为原始标签
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
// 构建Pipeline
val pipeline = new Pipeline().setStages(Array(labelIndexer, assembler, rf, labelConverter))
// 训练模型
val model = pipeline.fit(trainingData)
// 预测
val predictions = model.transform(testData)
// 评估模型
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")
// 输出ROC曲线
val roc = model.summary.roc
roc.show()
// 输出混淆矩阵
val confusionMatrix: DataFrame = model.summary
.asInstanceOf[org.apache.spark.ml.classification.RandomForestClassificationModel.Summary]
.confusionMatrix
confusionMatrix.show()
}
}
```
以上代码中,我们使用了Spark MLlib提供的随机森林分类器(RandomForestClassifier)来训练模型,并对测试集进行预测和评估。同时,我们还使用了评估指标中的ROC曲线和混淆矩阵来评价模型的性能,分别使用了`model.summary.roc`和`model.summary.confusionMatrix`来输出结果。
阅读全文