写一个线性判别分析(Linear Discriminant Analysis的图像python代码
时间: 2024-01-21 20:01:59 浏览: 67
LDA-线性判别分析代码
3星 · 编辑精心推荐
好的,下面是一个简单的线性判别分析的 Python 代码,用于对二维数据进行分类并绘制决策边界的图像:
```python
import numpy as np
import matplotlib.pyplot as plt
class LinearDiscriminantAnalysis():
def __init__(self):
self.w = None
def fit(self, X, y):
n_features = X.shape[1]
class_labels = np.unique(y)
mean_overall = np.mean(X, axis=0)
Sw = np.zeros((n_features, n_features))
Sb = np.zeros((n_features, n_features))
for c in class_labels:
X_c = X[y == c]
mean_c = np.mean(X_c, axis=0)
Sw += (X_c - mean_c).T.dot(X_c - mean_c)
n_c = X_c.shape[0]
mean_diff = (mean_c - mean_overall).reshape(n_features, 1)
Sb += n_c * (mean_diff).dot(mean_diff.T)
self.w = np.linalg.inv(Sw).dot(Sb).dot(mean_diff)
def predict(self, X):
projection = X.dot(self.w)
return projection
def plot_decision_boundary(self, X, y):
x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, 0.1),
np.arange(x2_min, x2_max, 0.1))
Z = self.predict(np.c_[xx1.ravel(), xx2.ravel()])
Z = Z.reshape(xx1.shape)
plt.contourf(xx1, xx2, Z, alpha=0.4)
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()
```
其中,plot_decision_boundary 函数用于绘制决策边界,传入的参数 X 是一个二维的 numpy 数组,表示特征矩阵;y 是一个一维的 numpy 数组,表示每个样本对应的标签。该函数会将决策边界绘制在特征平面上,同时将样本点按照类别用不同的颜色进行标记。
阅读全文