线性判别分析分类一维光谱数据的代码
时间: 2023-07-05 13:08:07 浏览: 72
LDA-线性判别分析代码
3星 · 编辑精心推荐
以下是使用线性判别分析(LDA)分类一维光谱数据的 Python 代码示例:
```python
import numpy as np
import matplotlib.pyplot as plt
# 创建一维光谱数据
x1 = np.random.normal(0, 1, 50)
x2 = np.random.normal(4, 1, 50)
# 绘制数据分布
plt.scatter(x1, np.zeros_like(x1), label='class 1')
plt.scatter(x2, np.zeros_like(x2), label='class 2')
plt.legend()
plt.show()
# 计算均值和协方差矩阵
mean1 = np.mean(x1)
mean2 = np.mean(x2)
cov = np.cov(np.vstack((x1, x2)))
# 计算投影方向
w = np.linalg.inv(cov).dot(mean2 - mean1)
# 投影数据
y1 = w.T.dot(x1)
y2 = w.T.dot(x2)
# 绘制投影后的数据
plt.scatter(y1, np.zeros_like(y1), label='class 1')
plt.scatter(y2, np.zeros_like(y2), label='class 2')
plt.legend()
plt.show()
# 计算分类阈值
threshold = (mean1 + mean2) / 2
# 预测新样本类别
x_new = 2
y_new = w.T.dot(x_new)
if y_new >= threshold:
print('class 2')
else:
print('class 1')
```
该代码首先生成了两个高斯分布的一维数据,然后计算了它们的均值和协方差矩阵,通过计算协方差矩阵的逆矩阵与两个类别的均值之差的乘积来得到投影方向。然后,将原始数据投影到该方向上,并计算分类阈值。最后,用一个新的样本进行预测并输出其类别。
阅读全文