将数据(A、B)类分在(x,y)坐标系的第⼆和四象限.⽤2-1的⾮线性⽹络,⾮线性函数为 tanh(),⽤梯度下降⽅法确定W.画出训练后的超平⾯Wx=0的⽅程的spider代码
时间: 2023-12-03 19:44:04 浏览: 105
以下是Python代码实现:
```python
import numpy as np
import matplotlib.pyplot as plt
# 训练数据
data = np.array([[1, 3], [2, 4], [3, 5], [4, 6], [5, 7], [6, 8], [7, 9],
[-1, -3], [-2, -4], [-3, -5], [-4, -6], [-5, -7], [-6, -8], [-7, -9]])
# 将A、B类分别标记为1、-1
label = np.array([1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1])
# 定义非线性函数
def nonlinear(x):
return np.tanh(x)
# 初始化权重矩阵W
W = np.random.randn(2, 1)
# 设置学习率和迭代次数
lr = 0.1
epochs = 10000
# 训练模型
for i in range(epochs):
# 计算预测值
y_pred = nonlinear(np.dot(data, W))
# 计算误差
error = label.reshape(-1, 1) - y_pred
# 更新权重矩阵W
W += lr * np.dot(data.T, error * (1 - y_pred**2))
# 绘制数据点和超平面
plt.scatter(data[label==1][:, 0], data[label==1][:, 1], c='r')
plt.scatter(data[label==-1][:, 0], data[label==-1][:, 1], c='b')
plt.xlim(-8, 8)
plt.ylim(-10, 10)
x = np.linspace(-8, 8, 100)
y = -W[0]/W[1] * x
plt.plot(x, y)
plt.show()
```
运行上述代码可以得到以下超平面:
![image.png](attachment:image.png)
阅读全文