已知员工工龄与平均工资之间有一定关系,定义两个数组分别存放工龄数据和平均工资数据,训练线性回归模型,绘制图像的python代码
时间: 2024-10-08 10:13:00 浏览: 28
首先,你需要安装一些必要的库,如`pandas`, `numpy`, `matplotlib`, 和 `sklearn`来进行数据处理、模型训练以及绘图。以下是使用Python定义数组并训练线性回归模型的基本步骤:
```python
# 导入所需的库
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
# 假设你有两列数据,分别是工龄(years_of_experience)和平均工资(average_salary)
data = {
'years_of_experience': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # 填充实际的数据
'average_salary': [20000, 22000, 24000, 26000, 28000, 30000, 32000, 34000, 36000, 38000] # 填充实际的数据
}
df = pd.DataFrame(data)
# 将数据拆分为特征(X)和目标变量(y)
X = df['years_of_experience'].values.reshape(-1, 1)
y = df['average_salary'].values.reshape(-1, 1)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建并训练线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)
# 预测和评估
y_pred = model.predict(X_test)
print('模型系数:', model.coef_)
print('截距:', model.intercept_)
# 绘制散点图和回归线
plt.scatter(X_train, y_train, label='训练数据')
plt.plot(X_train, model.predict(X_train), color='red', label='预测线')
plt.xlabel('工龄')
plt.ylabel('平均工资')
plt.legend()
plt.title('工龄与平均工资的关系')
plt.show()
阅读全文