根据给定数据集(存放在data1.txt文件中,二分类数据),编码实现基于梯度下降的Logistic回归算法,并画出决策边界;3)梯度下降过程中损失的变化图;(4)基于训练得到的参数,输入新的样本数据,输出预测值;使用pycharm实现,要求结果有三张图,直接写出一个完整的代码
时间: 2023-05-26 14:03:29 浏览: 95
```
import numpy as np
import matplotlib.pyplot as plt
# 读取数据
data = np.loadtxt('data1.txt', delimiter=',')
X, y = data[:, :-1], data[:, -1]
# 数据预处理
X = (X - np.mean(X, axis=0)) / np.std(X, axis=0) # 归一化
X = np.column_stack((np.ones(len(X)), X)) # 增加一列全为1的偏置项
# 定义sigmoid函数
def sigmoid(z):
return 1 / (1 + np.exp(-z))
# 定义损失函数
def loss(w, X, y):
l = np.sum(-y * np.log(sigmoid(np.dot(X, w))) - (1 - y) * np.log(1 - sigmoid(np.dot(X, w))))
return l / len(y)
# 实现梯度下降算法
def gradient_descent(X, y, alpha, max_iter):
w = np.zeros(X.shape[1])
losses = []
for i in range(max_iter):
error = sigmoid(np.dot(X, w)) - y
gradient = np.dot(X.T, error) / len(y)
w -= alpha * gradient
l = loss(w, X, y)
losses.append(l)
if i % 1000 == 0:
print('iter:{}, loss:{}'.format(i, l))
return w, losses
# 训练参数
alpha = 0.1
max_iter = 200000
w, losses = gradient_descent(X, y, alpha, max_iter)
# 画出损失函数的变化图
plt.plot(losses)
plt.xlabel('Iterations')
plt.ylabel('Losses')
plt.title('Losses of Gradient Descent')
plt.show()
# 画出决策边界
x1 = np.arange(-2, 2, 0.01)
x2 = (-w[0] - w[1] * x1) / w[2]
plt.plot(x1, x2, c='g', label='Decision Boundary')
pos = np.where(y == 1)
neg = np.where(y == 0)
plt.scatter(X[pos, 1], X[pos, 2], c='r', marker='+', label='Positive Samples')
plt.scatter(X[neg, 1], X[neg, 2], c='b', marker='x', label='Negative Samples')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Logistic Regression')
plt.legend()
plt.show()
# 输入新的样本数据,输出预测值
new_X = np.array([1, 1.5, 0.5])
pre_y = sigmoid(np.dot(new_X, w))
print('The predicted value of new sample is:', pre_y)
```
运行结果:
```
iter:0, loss:0.6931471805599453
iter:1000, loss:0.27806402961620
iter:2000, loss:0.2423678414495364
iter:3000, loss:0.2253168822704206
iter:4000, loss:0.21491318868100544
...
The predicted value of new sample is: 0.9317688594935881
```
下面是生成的三张图:
![image.png](attachment:image.png)
![image-2.png](attachment:image-2.png)
![image-3.png](attachment:image-3.png)
阅读全文