def draw(x, y): y = y[0] x_po = x[np.where(y == 1)] x_ne = x[np.where(y == 0)] #绘制散点图 ax = plt.axes(projection = "3d") x_1 = x_po[0,:] y_1 = x_po[1,:] z_1 = x_po[2,:] x_2 = x_ne[0,:] y_2 = x_ne[1,:] z_2 = x_ne[2,:] #p = 0.5的面 a,b,c,d = w x = np.linspace(-3, 3, 3) y = np.linspace(-3, 3, 3) x_3, y_3 = np.meshgrid(x, y) z_3 = -(a * x_3 + b * y_3 +d) / c ax.scatter(x_1, y_1, z_1, c = "r", label = "Positive") ax.scatter(x_2, y_2, z_2, c = "b", label = "Negative") ax.plot_surface(x_3, y_3, z_3, alpha = 0.5) plt.legend() plt.show() draw(X_train, y_train)
时间: 2023-05-19 17:03:44 浏览: 101
这是一个Python函数,名为draw,它有两个参数x和y。函数的作用是将数据x和标签y中的正例和负例分别提取出来,用于可视化或其他操作。具体来说,函数首先将标签y的第一个元素赋值给变量y,然后使用numpy库中的where函数,找出标签y中等于1的位置,将对应位置的数据x提取出来,存储在变量x_po中;同样地,找出标签y中等于0的位置,将对应位置的数据x提取出来,存储在变量x_ne中。最后,函数返回x_po和x_ne两个变量。
阅读全文