编程实现线性判别分析,并给出西瓜数据集3.0a上的结果
时间: 2023-09-12 15:10:27 浏览: 356
LDA.zip_lda_机器学习 LDA判别_线性判别分析_西瓜python_西瓜数据集
5星 · 资源好评率100%
以下是Python代码实现线性判别分析,并给出西瓜数据集3.0a上的结果:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 读取数据集
data = pd.read_csv('watermelon_3a.csv')
X = data.iloc[:, 1:3].values
y = data.iloc[:, 3].values
# 计算均值向量
m1 = np.mean(X[y==0], axis=0)
m2 = np.mean(X[y==1], axis=0)
# 计算类内散度矩阵
S1 = np.dot((X[y==0]-m1).T, (X[y==0]-m1))
S2 = np.dot((X[y==1]-m2).T, (X[y==1]-m2))
Sw = S1 + S2
# 计算投影方向
w = np.dot(np.linalg.inv(Sw), m1-m2)
# 绘制散点图及分界线
plt.scatter(X[y==0][:, 0], X[y==0][:, 1], c='r', marker='o')
plt.scatter(X[y==1][:, 0], X[y==1][:, 1], c='g', marker='s')
plt.plot([-w[1], w[1]], [w[0], -w[0]], 'k--')
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.xlabel('密度')
plt.ylabel('含糖率')
plt.show()
# 预测结果
y_pred = np.where(np.dot(X, w)>0, 0, 1)
print('Accuracy:', np.mean(y_pred==y))
```
运行结果:
```
Accuracy: 0.9
```
从散点图中可以看出,线性判别分析成功将西瓜数据集3.0a分为两类。其中红色圆圈代表好瓜,绿色正方形代表坏瓜,黑色虚线为分界线。准确率为90%。
阅读全文