Spark essentials for beginners: code for predicting survival on the Titanic
Date: 2023-08-11 19:08:34
Here is a Spark code example that predicts which passengers survived the Titanic:
```python
# Import the Spark libraries
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Create a SparkSession
spark = SparkSession.builder.appName("Titanic Prediction").getOrCreate()

# Read the dataset (CSV with a header row; column types are inferred)
data = spark.read.csv("titanic.csv", header=True, inferSchema=True)

# Drop columns that the model does not use
data = data.drop("PassengerId", "Name", "Ticket", "Cabin")

# Handle missing values: fillna(0) only fills numeric columns,
# so this replaces missing Age (and any missing Fare) with 0
data = data.fillna(0)

# Convert string features to numeric indices.
# handleInvalid="keep" maps null or unseen values (e.g. a missing Embarked)
# to an extra index instead of raising an error during transform
genderIndexer = StringIndexer(inputCol="Sex", outputCol="GenderIndex", handleInvalid="keep")
embarkedIndexer = StringIndexer(inputCol="Embarked", outputCol="EmbarkedIndex", handleInvalid="keep")
data = genderIndexer.fit(data).transform(data)
data = embarkedIndexer.fit(data).transform(data)

# Feature engineering: assemble the input columns into one vector column
assembler = VectorAssembler(
    inputCols=["Pclass", "Age", "SibSp", "Parch", "Fare", "GenderIndex", "EmbarkedIndex"],
    outputCol="features")
data = assembler.transform(data)

# Split into training (70%) and test (30%) sets
train, test = data.randomSplit([0.7, 0.3], seed=42)

# One evaluator for both tuning and testing. labelCol must be "Survived":
# the default labelCol is "label", which does not exist in this dataset
evaluator = BinaryClassificationEvaluator(
    labelCol="Survived", rawPredictionCol="rawPrediction", metricName="areaUnderROC")

# Decision tree: tune maxDepth with 5-fold cross-validation
dt = DecisionTreeClassifier(labelCol="Survived", featuresCol="features")
dtParamGrid = ParamGridBuilder().addGrid(dt.maxDepth, [3, 5, 7]).build()
dtCrossValidator = CrossValidator(estimator=dt, estimatorParamMaps=dtParamGrid, evaluator=evaluator, numFolds=5)
dtModel = dtCrossValidator.fit(train)

# Random forest: tune numTrees and maxDepth with 5-fold cross-validation
rf = RandomForestClassifier(labelCol="Survived", featuresCol="features")
rfParamGrid = ParamGridBuilder().addGrid(rf.numTrees, [10, 20, 50]).addGrid(rf.maxDepth, [3, 5, 7]).build()
rfCrossValidator = CrossValidator(estimator=rf, estimatorParamMaps=rfParamGrid, evaluator=evaluator, numFolds=5)
rfModel = rfCrossValidator.fit(train)

# Evaluate the best model from each search on the held-out test set
dtPredictions = dtModel.transform(test)
rfPredictions = rfModel.transform(test)
dtAUC = evaluator.evaluate(dtPredictions)
rfAUC = evaluator.evaluate(rfPredictions)

# Print model performance
print("Decision Tree AUC: " + str(dtAUC))
print("Random Forest AUC: " + str(rfAUC))
```
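This example covers data loading, data cleaning, feature engineering, model building, and model evaluation. It uses two classification algorithms, a decision tree and a random forest, to predict survival on the Titanic, and tunes each model's hyperparameters (maxDepth for both, plus numTrees for the random forest) with grid search (ParamGridBuilder) and 5-fold cross-validation (CrossValidator). Finally, it prints each model's AUC (area under the ROC curve) on the held-out test set as the performance metric.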
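To make the grid-search-plus-cross-validation step more concrete, here is a minimal stdlib-only Python sketch of the loop that `CrossValidator` conceptually runs: for each candidate in the parameter grid, train on k-1 folds, score the held-out fold by AUC, average across folds, and keep the candidate with the highest mean. Everything in it is a hypothetical illustration, not Spark's implementation: a toy k-nearest-neighbour scorer stands in for the decision tree, its `k` plays the role of `maxDepth`, and the data is synthetic rather than the Titanic set. AUC is computed as the probability that a random positive outscores a random negative (ties count as 0.5), which equals the area under the ROC curve. One deliberate difference: the folds here are stratified by label so every fold contains both classes, whereas Spark's `CrossValidator` splits rows randomly.

```python
import random
from statistics import mean


def auc(labels, scores):
    """ROC AUC: share of (positive, negative) pairs where the positive scores higher; ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def knn_scores(train, rows, k):
    """Hypothetical toy model: score = share of label-1 rows among the k nearest training rows."""
    out = []
    for x, _ in rows:
        nearest = sorted(train, key=lambda r: sum((a - b) ** 2 for a, b in zip(r[0], x)))[:k]
        out.append(sum(y for _, y in nearest) / k)
    return out


def stratified_folds(rows, num_folds, seed=42):
    """Shuffle, stable-sort by label, then deal round-robin so each fold gets both classes."""
    shuffled = rows[:]
    random.Random(seed).shuffle(shuffled)
    shuffled.sort(key=lambda r: r[1])
    return [shuffled[i::num_folds] for i in range(num_folds)]


def cross_validate(rows, grid, num_folds=5):
    """Return (best candidate, mean validation AUC per candidate); larger AUC is better."""
    folds = stratified_folds(rows, num_folds)
    avg_metrics = []
    for k in grid:
        fold_aucs = []
        for i, valid in enumerate(folds):
            train = [r for j, fold in enumerate(folds) if j != i for r in fold]
            labels = [y for _, y in valid]
            fold_aucs.append(auc(labels, knn_scores(train, valid, k)))
        avg_metrics.append(mean(fold_aucs))
    best = grid[max(range(len(grid)), key=lambda i: avg_metrics[i])]
    return best, avg_metrics


# Hypothetical synthetic data: rows are ((feature1, feature2), label); label 1 shifts feature1 up
rng = random.Random(0)
rows = [((y + rng.gauss(0, 0.8), rng.gauss(0, 1)), y) for y in [0, 1] * 50]
best_k, avg_metrics = cross_validate(rows, grid=[1, 3, 5])
print(best_k, [round(m, 3) for m in avg_metrics])
```

In Spark, the winning parameter map is then refit on the whole training set to produce `CrossValidatorModel.bestModel`, and the per-candidate fold averages are available as `avgMetrics`.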