import argparse import numpy as np import pandas as pd from sklearn import model_selection from sklearn import preprocessing from sklearn import linear_model from sklearn import metrics import joblib from config import * def train(x_train, x_test, y_train, y_test): estimator = linear_model.Ridge() estimator.fit(x_train, y_train) print('梯度下降的权重系数是:', estimator.coef_) print('梯度下降的偏置是:', estimator.intercept_) joblib.dump(estimator, model_save_path) y_predict = estimator.predict(x_test) err = metrics.mean_squared_error(y_test, y_predict) print('梯度下降的误差率为:', err) def pred(x): estimator = joblib.load(model_save_path) x= np.array(x) predict = estimator.predict(x.reshape((x.shape[0], 1))) return predict def main(): parser = argparse.ArgumentParser(description='Demo of argparse') parser.add_argument('--method', type=str, default='train') args = parser.parse_args() method = args.method if method == 'train': df = pd.read_excel(file_name, dtype={ '年份': int, '值': float }) x = df['年份'].to_numpy() x = x.reshape((x.shape[0], 1)) y = df['值'].to_numpy() x_train, x_test, y_train, y_test=model_selection.train_test_split(x, y) train(x_train, x_test, y_train, y_test) elif method == 'test': predict = pred(pred_year).round(2) df = pd.DataFrame({ '预测年份': pred_year, '预测结果(单位:万吨)': predict }) print(df) df.to_excel(save_pred_path, index=False) else: print('wrong') if __name__ == '__main__': main()
时间: 2024-02-19 11:59:39 浏览: 138
Python 机器学习工具包SKlearn的安装与使用
5星 · 资源好评率100%
这段 Python 代码主要实现了一个基于 Ridge 回归模型的数据训练和预测程序。具体来说,代码分为以下几个部分:
1. 导入需要的库和模块,包括 argparse、numpy、pandas、scikit-learn 和 joblib。
2. 实现一个 train 函数,用于训练 Ridge 模型。该函数接收训练集和测试集的特征和标签数据,通过 scikit-learn 库中的 Ridge 类进行模型训练,输出模型的权重系数、偏置和误差率,并将训练好的模型保存到指定路径。
3. 实现一个 pred 函数,用于对给定的年份进行预测。该函数接收一个年份数组,从指定路径加载已经训练好的 Ridge 模型,对年份数组进行预测,并返回预测结果。
4. 实现一个 main 函数,用于控制整个程序的流程。该函数通过 argparse 模块解析命令行参数,从而决定程序要进行的操作类型。如果是训练操作,就从指定路径的 Excel 文件中读取数据,将年份和对应的值作为特征和标签,通过 train 函数进行训练。如果是预测操作,就调用 pred 函数进行预测,并将预测结果保存到指定路径的 Excel 文件中。如果命令行参数有误,则输出错误信息。
5. 在最后,通过 if __name__ == '__main__': 判断当前文件是否被作为模块导入,如果是,则不执行 main 函数,如果直接运行该文件,则执行 main 函数。
总的来说,这个程序使用 Ridge 回归模型对年份和对应的值进行训练,并且可以对未来的年份进行预测。
阅读全文