逻辑回归 混交矩阵 scala代码
时间: 2023-06-30 19:12:51 浏览: 136
以下是使用Scala语言编写的逻辑回归模型混淆矩阵代码示例:
```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{DataFrame, SparkSession}
object LogisticRegressionConfusionMatrixExample {
def main(args: Array[String]): Unit = {
// 创建 SparkSession
val spark = SparkSession.builder()
.appName("LogisticRegressionConfusionMatrixExample")
.master("local[*]")
.getOrCreate()
// 加载数据集
val data = spark.read.format("libsvm")
.load("data/sample_libsvm_data.txt")
// 将数据集分为训练集和测试集
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)
// 创建逻辑回归模型
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(0.3)
.setElasticNetParam(0.8)
// 训练模型
val lrModel = lr.fit(trainingData)
// 在测试集上进行预测
val predictions = lrModel.transform(testData)
// 创建评估器并计算准确率
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Accuracy: $accuracy")
// 计算混淆矩阵
val predictionAndLabel = predictions.select("prediction", "label")
.rdd.map(r => (r.getDouble(0), r.getDouble(1)))
val metrics = new MulticlassMetrics(predictionAndLabel)
val confusionMatrix = metrics.confusionMatrix
println(s"Confusion Matrix:\n$confusionMatrix")
// 关闭 SparkSession
spark.stop()
}
}
```
在这个示例中,我们使用逻辑回归模型对一个简单的数据集进行分类,然后使用MulticlassMetrics类计算混淆矩阵,并打印出来。代码中的数据集是Spark自带的一个样例数据集,可以在本地运行。
阅读全文