def plotDescisionBoundary(X, y, theta):具体代码
时间: 2024-05-12 08:19:58 浏览: 110
这是一个基本的决策边界绘制函数,它使用matplotlib库绘制二维数据点和线性决策边界。
```python
import numpy as np
import matplotlib.pyplot as plt
def plotDescisionBoundary(X, y, theta):
"""
Plots the data points X and y along with the decision boundary defined by theta.
"""
# Plot data points
plt.scatter(X[y==0, 0], X[y==0, 1], c='r', label='y=0')
plt.scatter(X[y==1, 0], X[y==1, 1], c='b', label='y=1')
# Define the range of x and y values
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1))
# Generate predictions for each point in the meshgrid
Z = np.dot(np.c_[xx.ravel(), yy.ravel(), np.ones(xx.ravel().shape)], theta)
Z = np.where(Z >= 0, 1, 0)
Z = Z.reshape(xx.shape)
# Plot decision boundary
plt.contour(xx, yy, Z, cmap=plt.cm.Paired)
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend()
plt.show()
```
该函数接受三个参数:数据集X,标签y和模型参数theta。这里假设X是一个二维矩阵,每行代表一个数据点,第一列是第一个特征,第二列是第二个特征。y是一个向量,代表每个数据点的标签,theta是一个长度为3的向量,代表线性模型的参数。
该函数首先使用matplotlib的scatter函数绘制数据点。然后,它使用meshgrid函数生成一个二维矩阵,代表所有可能的特征值的组合。它使用np.c_函数将这个矩阵与一个全为1的列向量连接起来,并计算每个点的预测结果。最后,它使用contour函数绘制决策边界,并使用xlabel和ylabel函数设置坐标轴标签。
阅读全文