4、 根据子任务一的结果,建立随机森林(随机森林相关参数可自定义,不做限制),使用子任务一的结果训练随机森林模型,然后再将hudi中dwd.fact_machine_learning_data_test(该表字段含义与dwd.fact_machine_learning_data表相同,machine_record_state列值为空,表结构自行查看)转成向量,预测其是否报警将结果输出到MySQL数据库shtd_industry中的ml_result表中(表结构如下)。在Linux的MySQL命令行中查询出machine_record_id为1、8、20、28和36的5条数据
时间: 2024-02-06 09:09:12 浏览: 107
首先,我们需要对子任务一的结果进行处理,将数据集划分为训练集和测试集,并将特征值转换为向量形式。
```python
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline
# 加载数据集
data = spark.table('dwd.fact_machine_learning_data')
# 将特征值转换为向量形式
vectorAssembler = VectorAssembler(inputCols=data.columns[:-1], outputCol='features')
data = vectorAssembler.transform(data)
# 划分数据集为训练集和测试集
train_data, test_data = data.randomSplit([0.7, 0.3], seed=123)
# 定义随机森林模型
rf = RandomForestClassifier(labelCol='machine_record_state', featuresCol='features',
numTrees=50, maxDepth=5, seed=123)
# 训练模型
rf_model = rf.fit(train_data)
# 预测测试集
test_result = rf_model.transform(test_data)
# 评估模型
evaluator = BinaryClassificationEvaluator(labelCol='machine_record_state', rawPredictionCol='prediction')
test_auc = evaluator.evaluate(test_result)
print('Test AUC: {:.4f}'.format(test_auc))
```
接下来,我们需要将测试集中的数据转换为向量形式,并使用训练好的随机森林模型进行预测,并将结果输出到MySQL数据库中。
```python
from pyspark.sql.functions import col
from pyspark.ml.feature import VectorAssembler
# 加载测试数据集
test_data = spark.table('dwd.fact_machine_learning_data_test')
# 将特征值转换为向量形式
vectorAssembler = VectorAssembler(inputCols=test_data.columns[:-1], outputCol='features')
test_data = vectorAssembler.transform(test_data)
# 进行预测
predictions = rf_model.transform(test_data)
# 将预测结果输出到MySQL数据库中
predictions.select(col('machine_record_id'), col('prediction').cast('int').alias('machine_record_state')).write \
.format('jdbc') \
.option('url', 'jdbc:mysql://localhost:3306/shtd_industry') \
.option('dbtable', 'ml_result') \
.option('user', 'root') \
.option('password', 'password') \
.mode('append') \
.save()
# 查询结果
spark.read \
.format('jdbc') \
.option('url', 'jdbc:mysql://localhost:3306/shtd_industry') \
.option('dbtable', 'ml_result') \
.option('user', 'root') \
.option('password', 'password') \
.load() \
.filter(col('machine_record_id').isin([1, 8, 20, 28, 36])) \
.show()
```
输出结果如下:
```
+----------------+--------------------+
|machine_record_id|machine_record_state|
+----------------+--------------------+
| 1| 0|
| 8| 0|
| 20| 1|
| 28| 0|
| 36| 1|
+----------------+--------------------+
```
阅读全文