Fisher linear discriminant analysis: Python code
Date: 2023-05-26 21:03:37
Below is a Python implementation of Fisher linear discriminant analysis (LDA):
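```python
import numpy as np

class FisherLD:
    """Fisher linear discriminant: finds the projection direction w that
    maximizes between-class scatter relative to within-class scatter."""

    def __init__(self):
        self.w = None  # projection direction, set by fit()

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        n_features = X.shape[1]
        s_w = np.zeros((n_features, n_features))  # within-class scatter
        s_b = np.zeros((n_features, n_features))  # between-class scatter
        m = X.mean(axis=0)  # overall mean
        for c in np.unique(y):
            x_c = X[y == c]
            m_c = x_c.mean(axis=0)
            d = x_c - m_c
            s_w += d.T @ d
            diff = (m_c - m).reshape(-1, 1)
            s_b += len(x_c) * (diff @ diff.T)
        # pinv instead of inv: s_w is singular when the features are collinear
        # (as in the example below), and np.linalg.inv would raise LinAlgError
        eig_val, eig_vec = np.linalg.eig(np.linalg.pinv(s_w) @ s_b)
        # eig may return a complex dtype; keep the real part of the leading eigenvector
        self.w = np.real(eig_vec[:, np.argmax(np.real(eig_val))])
        return self

    def predict(self, X):
        # Returns the 1-D projection of each sample onto w, not a class label
        return np.asarray(X, dtype=float) @ self.w
```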
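The `fit` method maximizes the Fisher criterion. In the notation of the code (`s_w` is $S_W$, `s_b` is $S_B$, `m_c` is $\mathbf m_c$, `m` is $\mathbf m$, and $n_c$ is the number of samples in class $c$), the quantities are the following. This is the standard textbook formulation, given here as a reference sketch:

$$
J(\mathbf w) = \frac{\mathbf w^\top S_B \mathbf w}{\mathbf w^\top S_W \mathbf w},\qquad
S_W = \sum_{c}\sum_{\mathbf x \in c} (\mathbf x - \mathbf m_c)(\mathbf x - \mathbf m_c)^\top,\qquad
S_B = \sum_{c} n_c (\mathbf m_c - \mathbf m)(\mathbf m_c - \mathbf m)^\top
$$

Setting $\nabla J(\mathbf w) = 0$ gives the generalized eigenproblem below, and $J$ is maximized by the eigenvector with the largest eigenvalue $\lambda$:

$$
S_B \mathbf w = \lambda S_W \mathbf w \quad\Longrightarrow\quad S_W^{-1} S_B \mathbf w = \lambda \mathbf w
$$

With two classes, $S_B = \frac{n_0 n_1}{n}(\mathbf m_1 - \mathbf m_0)(\mathbf m_1 - \mathbf m_0)^\top$ has rank 1, so the solution reduces to $\mathbf w \propto S_W^{-1}(\mathbf m_1 - \mathbf m_0)$.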
Example usage:
```python
X = np.array([
    [1, 2],
    [2, 3],
    [3, 4],
    [4, 5],
    [5, 6],
    [6, 7],
    [7, 8],
    [8, 9]
])
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
ld = FisherLD()
ld.fit(X, y)
x = np.array([
    [0, 0],
    [1, 1],
    [2, 2],
    [3, 3],
    [4, 4],
    [5, 5],
    [6, 6],
    [7, 7]
])
print(ld.predict(x))  # 1-D projections of the test points onto w
```
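Output: every training point lies on the line x₂ = x₁ + 1, so the within-class scatter matrix S_w = [[10, 10], [10, 10]] is singular. `np.linalg.inv` raises `LinAlgError` on it, and the pseudo-inverse `np.linalg.pinv` has to be used instead. With the pseudo-inverse, the leading eigenvector is w ≈ ±[0.7071, 0.7071] (the sign of an eigenvector is arbitrary), so each test point [k, k] projects to about ±1.414·k. For k = 0, …, 7 that gives roughly 0, 1.414, 2.828, 4.243, 5.657, 7.071, 8.485, 9.899, possibly all negated.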
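`predict` returns projections rather than labels. For two classes, the direction can also be computed in closed form as w = S_w⁺(m₁ − m₀), and a label can then be assigned by thresholding the projection. The sketch below is a minimal version of that idea. The functions `fisher_two_class` and `classify` are hypothetical helpers, not part of the code above. Placing the threshold midway between the projected class means is an assumption that corresponds to equal class priors.

```python
import numpy as np

def fisher_two_class(X, y):
    """Two-class Fisher direction w = pinv(S_w) @ (m1 - m0) plus a midpoint threshold.

    Assumes y holds exactly two distinct labels; the smaller one is class 0.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    c0, c1 = np.unique(y)  # ValueError if there are not exactly two labels
    x0, x1 = X[y == c0], X[y == c1]
    m0, m1 = x0.mean(axis=0), x1.mean(axis=0)
    # within-class scatter summed over both classes
    s_w = (x0 - m0).T @ (x0 - m0) + (x1 - m1).T @ (x1 - m1)
    # pseudo-inverse, so a singular S_w (collinear features) does not fail
    w = np.linalg.pinv(s_w) @ (m1 - m0)
    # hypothetical threshold: halfway between the projected class means
    threshold = 0.5 * (m0 @ w + m1 @ w)
    return w, threshold, (c0, c1)

def classify(X, w, threshold, labels):
    """Assign labels[1] where the projection exceeds the threshold, else labels[0]."""
    proj = np.asarray(X, dtype=float) @ w
    return np.where(proj > threshold, labels[1], labels[0])

X = np.array([[1, 2], [2, 3], [3, 4], [4, 5],
              [5, 6], [6, 7], [7, 8], [8, 9]])
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
w, t, labels = fisher_two_class(X, y)
print(classify(X, w, t, labels))  # prints [0 0 0 0 1 1 1 1]
```

On this data the closed form gives w ≈ [0.2, 0.2], which points along the same [1, 1] direction as the eigenvector found by `FisherLD.fit`. The threshold is 2.0, so the four training points of each class fall on the correct side.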