导入二维文件实现感知机绘制超平面的python代码
时间: 2023-12-28 17:53:00 浏览: 127
以下是导入二维文件实现感知机绘制超平面的Python代码:
```python
import numpy as np
import matplotlib.pyplot as plt
# 读取数据
data = np.loadtxt("data.txt", delimiter=",")
x = data[:, :-1]
y = data[:, -1]
# 初始化权重和偏置
w = np.zeros((1, x.shape[1]))
b = 0
# 定义训练函数
def train(x, y, w, b, lr, epochs):
for epoch in range(epochs):
for i in range(len(x)):
# 计算预测值
y_hat = np.dot(w, x[i].T) + b
# 更新权重和偏置
if y_hat * y[i] <= 0:
w += lr * y[i] * x[i]
b += lr * y[i]
return w, b
# 训练感知机
w, b = train(x, y, w, b, lr=0.1, epochs=100)
# 绘制超平面
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x[:,0], x[:,1], c=y)
x1_min, x1_max = x[:,0].min()-1, x[:,0].max()+1
x2_min, x2_max = x[:,1].min()-1, x[:,1].max()+1
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, 0.02), np.arange(x2_min, x2_max, 0.02))
Z = np.dot(w, np.array([xx1.ravel(), xx2.ravel()]).T) + b
Z = np.sign(Z)
Z = Z.reshape(xx1.shape)
ax.contourf(xx1, xx2, Z, alpha=0.1)
plt.show()
```
其中,`data.txt`为存储数据的文件,每一行代表一个数据点,最后一列为标签。在代码中,我们使用了`numpy`库来读取数据,并实现了一个`train`函数来训练感知机。最后,使用`matplotlib`库来绘制超平面。
阅读全文