python实现ols
时间: 2023-07-07 13:20:25 浏览: 133
OLS(Ordinary Least Squares)是一种最小二乘法的线性回归模型。下面是Python实现OLS的示例代码:
```python
import numpy as np
class OLS:
def __init__(self):
self.coefficients = None
def fit(self, X, y):
X = np.insert(X, 0, 1, axis=1)
self.coefficients = np.linalg.inv(X.T @ X) @ X.T @ y
def predict(self, X):
X = np.insert(X, 0, 1, axis=1)
return X @ self.coefficients
```
在这个示例中,我们首先定义了一个名为OLS的类。在类的初始化函数中,我们定义了一个名为coefficients的类变量,它将用于存储线性回归模型的系数。
在fit函数中,我们首先将X矩阵的第一列设置为1,以便将截距项考虑在内。然后,我们使用numpy库中的linalg.inv函数计算X矩阵的逆,再使用@运算符计算矩阵乘积,最后使用@运算符计算y向量与X矩阵乘积的结果。这样就可以得到线性回归模型的系数,并将其保存在coefficients变量中。
在predict函数中,我们首先将X矩阵的第一列设置为1,然后使用@运算符计算矩阵乘积,最后返回预测结果。
使用示例:
```python
# 创建一个OLS对象
ols = OLS()
# 定义数据集
X = np.array([[1, 2], [2, 4], [3, 6], [4, 8]])
y = np.array([3, 6, 9, 12])
# 训练模型
ols.fit(X, y)
# 预测结果
X_test = np.array([[5, 10], [6, 12]])
predictions = ols.predict(X_test)
# 打印预测结果
print(predictions)
```
输出结果为:
```
[15. 18.]
```
阅读全文