plot_decision_boundary函数
时间: 2024-05-04 08:18:52 浏览: 184
Plot Ellipses:绘制椭圆-matlab开发
plot_decision_boundary函数是一个用于绘制分类边界的函数,通常用于可视化分类模型的效果。它的实现过程如下:
1. 首先,我们需要定义一个二维平面上的网格,以便在上面绘制分类边界。我们可以使用numpy中的meshgrid函数来创建这个网格。
2. 接下来,我们需要将这个网格输入到我们的分类模型中,以获取每个点的分类结果。这个过程可以使用predict函数来完成。
3. 然后,我们可以使用matplotlib库来绘制这个分类边界。具体来说,我们可以使用contourf函数来绘制等高线图,并使用颜色来表示不同的分类结果。
下面是一个简单的示例代码,展示了如何使用plot_decision_boundary函数绘制一个线性分类模型的分类边界:
```
import numpy as np
import matplotlib.pyplot as plt
def plot_decision_boundary(model, X, y):
# Step 1: Create a meshgrid
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))
# Step 2: Predict the labels for all points in the meshgrid
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# Step 3: Plot the decision boundary
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
# Example usage
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
X, y = make_classification(n_features=2, n_redundant=0, n_informative=2,
random_state=1, n_clusters_per_class=1)
model = LogisticRegression()
model.fit(X, y)
plot_decision_boundary(model, X, y)
plt.show()
```
阅读全文