Naive Bayes - RDD-based API 输出结果
时间: 2023-08-09 13:09:56 浏览: 62
这里是一个 RDD-based API 的简单 Scala 代码示例,展示如何使用 Apache Spark 中的 MLlib 库来训练一个朴素贝叶斯分类器并输出结果:
```scala
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}
// 创建 SparkConf 和 SparkContext
val conf = new SparkConf().setAppName("NaiveBayesExample")
val sc = new SparkContext(conf)
// 读取数据集
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 划分数据集为训练集和测试集
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)
// 训练朴素贝叶斯分类器
val model = NaiveBayes.train(trainingData, lambda = 1.0, modelType = "multinomial")
// 在测试集上进行预测
val predictions = testData.map { point =>
val prediction = model.predict(point.features)
(prediction, point.label)
}
// 计算模型准确率
val accuracy = 1.0 * predictions.filter(r => r._1 == r._2).count() / testData.count()
println(s"Test set accuracy = $accuracy")
// 输出预测结果
predictions.take(10).foreach(println)
// 保存模型
model.save(sc, "myModelPath")
// 加载模型
val sameModel = NaiveBayesModel.load(sc, "myModelPath")
```
你可以将你的数据集替换文中的数据集路径,然后运行代码即可得到训练好的模型和预测结果。需要注意的是,在 RDD-based API 中,我们需要显式地调用 `predict` 方法来进行预测。