Implementing Linear Regression with NumPy
Date: 2024-05-29 12:06:33
Linear regression can be implemented with NumPy in the following steps:
1. Import NumPy and load the dataset
```
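import numpy as np

# Load a comma-separated file with x in column 0 and y in column 1
# (if the file starts with a header row, pass skip_header=1)
data = np.genfromtxt('data.csv', delimiter=',')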
```
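If no `data.csv` is at hand, a file in the layout step 1 expects can be generated. This is a minimal sketch: the line y = 3x + 4, the noise level, and the 100 rows are all hypothetical choices for illustration, not data from the original post.

```python
import numpy as np

# Hypothetical data: y = 3x + 4 plus Gaussian noise (values chosen only for illustration)
rng = np.random.default_rng(seed=0)
x = rng.uniform(0, 10, size=100)
y = 3 * x + 4 + rng.normal(0, 1, size=100)

# Write two comma-separated columns (x, y) with no header row,
# which is the layout np.genfromtxt(..., delimiter=',') reads back
np.savetxt('data.csv', np.column_stack([x, y]), delimiter=',')

data = np.genfromtxt('data.csv', delimiter=',')
print(data.shape)  # (100, 2)
```

With 100 rows, the split in step 2 leaves 80 rows for training and 20 for testing.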
2. Split the data into training and test sets
```
# All rows except the last 20 form the training set; the last 20 rows form the test set
x_train = data[:-20, 0]
y_train = data[:-20, 1]
x_test = data[-20:, 0]
y_test = data[-20:, 1]
```
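Taking the last 20 rows as the test set assumes the file is not sorted or grouped in any way. If it might be, shuffling the rows first is safer. The sketch below does that; `shuffle_split` is a hypothetical helper name, and the 100-row demo array is made up.

```python
import numpy as np

def shuffle_split(data, test_size=20, seed=0):
    """Shuffle the rows, then hold out the last `test_size` rows as the test set."""
    rng = np.random.default_rng(seed)
    shuffled = data[rng.permutation(len(data))]
    train, test = shuffled[:-test_size], shuffled[-test_size:]
    # Column 0 is x and column 1 is y, as in the split above
    return train[:, 0], train[:, 1], test[:, 0], test[:, 1]

# Demo array standing in for the loaded CSV (y = 2x)
demo = np.column_stack([np.arange(100.0), 2 * np.arange(100.0)])
x_train, y_train, x_test, y_test = shuffle_split(demo)
print(len(x_train), len(x_test))  # 80 20
```

Fixing `seed` keeps the split reproducible between runs.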
3. Initialize the model parameters (slope a and intercept b)
```
# Model: y = a * x + b, starting from a = b = 0
a = 0.0  # slope
b = 0.0  # intercept
```
4. Set the learning rate and the number of iterations
```
learning_rate = 0.0001
num_iterations = 1000
```
5. Define the loss function (mean squared error)
```
def compute_cost(a, b, x_train, y_train):
    # Mean squared error of the line y = a * x + b over all N samples
    total_cost = 0
    N = len(x_train)
    for i in range(N):
        total_cost += (y_train[i] - (a * x_train[i] + b)) ** 2
    return total_cost / N
```
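The loop computes $J(a,b) = \frac{1}{N}\sum_{i=1}^{N}\bigl(y_i - (a x_i + b)\bigr)^2$. Since `x_train` and `y_train` are NumPy arrays, the same value can be computed without a Python loop. A sketch, using the hypothetical name `compute_cost_vectorized` and a tiny made-up dataset:

```python
import numpy as np

def compute_cost_vectorized(a, b, x, y):
    """Mean squared error of y ≈ a*x + b, computed with array operations."""
    residuals = y - (a * x + b)
    return np.mean(residuals ** 2)

x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])
print(compute_cost_vectorized(2.0, 0.0, x, y))  # 0.0 (the line y = 2x fits exactly)
print(compute_cost_vectorized(0.0, 0.0, x, y))  # (4 + 16 + 36) / 3 ≈ 18.667
```

Vectorizing moves the per-sample loop into NumPy's compiled code, which matters once N grows beyond a few thousand rows.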
6. Define the gradient descent function
```
def gradient_descent(a, b, x_train, y_train, learning_rate, num_iterations):
    N = len(x_train)
    for i in range(num_iterations):
        # Accumulate the partial derivatives of the MSE with respect to a and b
        grad_a = 0
        grad_b = 0
        for j in range(N):
            grad_a += (2/N) * x_train[j] * ((a * x_train[j] + b) - y_train[j])
            grad_b += (2/N) * ((a * x_train[j] + b) - y_train[j])
        # Move both parameters one step against the gradient
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
    return a, b
```
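The two accumulators in the inner loop are the partial derivatives of the loss $J(a,b)$ from step 5, and each outer iteration applies one update with step size $\eta$ (`learning_rate`):

$$
\begin{aligned}
J(a,b) &= \frac{1}{N}\sum_{i=1}^{N}\bigl(y_i - (a x_i + b)\bigr)^2 \\
\frac{\partial J}{\partial a} &= \frac{2}{N}\sum_{i=1}^{N} x_i\bigl((a x_i + b) - y_i\bigr) \\
\frac{\partial J}{\partial b} &= \frac{2}{N}\sum_{i=1}^{N} \bigl((a x_i + b) - y_i\bigr) \\
a &\leftarrow a - \eta\,\frac{\partial J}{\partial a}, \qquad b \leftarrow b - \eta\,\frac{\partial J}{\partial b}
\end{aligned}
$$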
7. Train the model to obtain the optimized parameters
```
a, b = gradient_descent(a, b, x_train, y_train, learning_rate, num_iterations)
```
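The same updates can be written with NumPy array operations, and the result can be sanity-checked against `np.polyfit(x, y, 1)`, which returns the closed-form least-squares slope and intercept. This is a sketch: `gradient_descent_vectorized` is a hypothetical name, and the noise-free data on y = 3x + 4 and the learning rate of 0.1 are illustrative choices, not the post's settings.

```python
import numpy as np

def gradient_descent_vectorized(a, b, x, y, learning_rate, num_iterations):
    """Same updates as the loop version, but each gradient is one NumPy expression."""
    n = len(x)
    for _ in range(num_iterations):
        error = (a * x + b) - y                  # prediction minus target, shape (N,)
        grad_a = (2 / n) * np.dot(x, error)      # dJ/da
        grad_b = (2 / n) * np.sum(error)         # dJ/db
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
    return a, b

# Hypothetical noise-free data on y = 3x + 4
x = np.linspace(0, 1, 50)
y = 3 * x + 4
a, b = gradient_descent_vectorized(0.0, 0.0, x, y, learning_rate=0.1, num_iterations=5000)

# Closed-form least-squares fit of a degree-1 polynomial: [slope, intercept]
slope, intercept = np.polyfit(x, y, 1)
print(a, b)
print(slope, intercept)
```

If the gradient-descent parameters stay far from the `np.polyfit` values, the learning rate is probably too small or the iteration count too low for that data.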
8. Evaluate the model and visualize the result
```
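import matplotlib.pyplot as plt

# Scatter the training data and overlay the fitted line in red
plt.scatter(x_train, y_train)
plt.plot(x_train, a * x_train + b, 'r')
plt.show()

# Compare the MSE on the training set with the MSE on the held-out test set
print('Train cost:', compute_cost(a, b, x_train, y_train))
print('Test cost:', compute_cost(a, b, x_test, y_test))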
```
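Whether `learning_rate = 0.0001` and `num_iterations = 1000` from step 4 are enough depends on the data. One way to check is to record the loss at every iteration and see whether it has flattened out. A minimal sketch, with the hypothetical helper `gradient_descent_with_history` and made-up noise-free data on y = 3x + 4:

```python
import numpy as np

def gradient_descent_with_history(a, b, x, y, learning_rate, num_iterations):
    """Run gradient descent on the MSE loss and record the loss before every update."""
    n = len(x)
    history = []
    for _ in range(num_iterations):
        error = (a * x + b) - y
        history.append(np.mean(error ** 2))   # MSE at the current (a, b)
        a -= learning_rate * (2 / n) * np.dot(x, error)
        b -= learning_rate * (2 / n) * np.sum(error)
    return a, b, history

# Hypothetical data; compare a very small and a moderate learning rate
x = np.linspace(0, 10, 50)
y = 3 * x + 4
results = {}
for lr in (0.0001, 0.01):
    _, _, history = gradient_descent_with_history(0.0, 0.0, x, y, lr, 1000)
    results[lr] = history
    print(f"lr={lr}: initial MSE {history[0]:.4f}, final MSE {history[-1]:.4f}")
```

Passing a history list to `plt.plot` shows the loss curve. A curve that is still falling steeply at the last iteration suggests a larger learning rate or more iterations, while a loss that grows from one iteration to the next means the learning rate is too large.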