预测泰坦尼克号上的生还情况scala
时间: 2023-07-29 13:11:37 浏览: 106
泰坦尼克号幸存预测
好的,下面是预测泰坦尼克号上的生还情况的Scala代码:
1. 数据探索和清洗
```scala
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().appName("Titanic").getOrCreate()
val df = spark.read.format("csv").option("header", "true").load("titanic_train.csv")
df.show()
// 查看数据集是否有缺失值
df.na.drop().show()
// 清洗数据集
val cleanedDf = df.na.drop().drop("PassengerId", "Name", "Ticket", "Cabin", "Embarked")
cleanedDf.show()
```
2. 特征工程
```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
// 对类别型特征进行编码
val genderIndexer = new StringIndexer().setInputCol("Sex").setOutputCol("Gender")
val indexedDf = genderIndexer.fit(cleanedDf).transform(cleanedDf)
indexedDf.show()
// 创建新的特征
val assembler = new VectorAssembler().setInputCols(Array("Pclass", "Age", "SibSp", "Parch", "Fare", "Gender")).setOutputCol("features")
val transformedDf = assembler.transform(indexedDf)
transformedDf.show()
```
3. 模型选择和训练
```scala
import org.apache.spark.ml.classification.LogisticRegression
// 训练逻辑回归模型
val lr = new LogisticRegression().setLabelCol("Survived").setFeaturesCol("features")
val lrModel = lr.fit(transformedDf)
lrModel.summary.accuracy
```
4. 模型评估和优化
```scala
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
// 定义超参数网格
val paramGrid = new ParamGridBuilder().addGrid(lr.regParam, Array(0.1, 0.01)).addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)).build()
// 定义交叉验证器
val cv = new CrossValidator().setEstimator(lr).setEvaluator(new BinaryClassificationEvaluator).setEstimatorParamMaps(paramGrid).setNumFolds(5)
// 进行交叉验证
val cvModel = cv.fit(transformedDf)
// 查看最佳模型的评估指标
cvModel.bestModel.summary.accuracy
```
5. 预测
```scala
val testDf = spark.read.format("csv").option("header", "true").load("titanic_test.csv")
testDf.show()
// 清洗测试数据集
val cleanedTestDf = testDf.na.drop().drop("PassengerId", "Name", "Ticket", "Cabin", "Embarked")
cleanedTestDf.show()
// 对测试数据集进行特征工程
val indexedTestDf = genderIndexer.fit(cleanedTestDf).transform(cleanedTestDf)
val transformedTestDf = assembler.transform(indexedTestDf)
transformedTestDf.show()
// 使用已训练好的模型进行预测
val predictions = cvModel.transform(transformedTestDf)
predictions.show()
```
以上就是使用Scala预测泰坦尼克号上的生还情况的代码。
阅读全文