Clustering the Iris dataset with a pure-Python DBSCAN implementation
Date: 2023-09-16 19:14:54 Views: 171
Sure. Below is a DBSCAN implementation written in plain Python (NumPy only, no scikit-learn) that clusters the Iris data:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Load the dataset (assumes the first four columns of iris.csv are the features)
def load_dataset():
    iris = pd.read_csv('iris.csv')
    X = iris.iloc[:, [0, 1, 2, 3]].values
    return X

# Euclidean distance between two samples
def dist(a, b):
    return np.sqrt(np.sum((a - b) ** 2))

# DBSCAN: label 0 = unassigned, -1 = noise, 1..C = cluster id
def dbscan(X, eps, min_pts):
    m = X.shape[0]
    visited = np.zeros(m, dtype=bool)
    clusters = np.zeros(m, dtype=int)
    C = 0
    for i in range(m):
        if not visited[i]:
            visited[i] = True
            N = get_neighbors(X, i, eps)
            if len(N) < min_pts:
                clusters[i] = -1  # provisional noise; may become a border point later
            else:
                C += 1
                clusters[i] = C
                expand_cluster(X, visited, clusters, i, N, C, eps, min_pts)
    return clusters

# Indices of all points within eps of X[i] (the point itself is included)
def get_neighbors(X, i, eps):
    return np.where(np.array([dist(X[i], x) for x in X]) <= eps)[0]

# Grow cluster C outward from the seed neighborhood N
def expand_cluster(X, visited, clusters, i, N, C, eps, min_pts):
    queue = list(N)
    k = 0
    while k < len(queue):  # the queue grows as new core points are found
        j = queue[k]
        k += 1
        if not visited[j]:
            visited[j] = True
            N2 = get_neighbors(X, j, eps)
            if len(N2) >= min_pts:
                queue.extend(N2)  # j is a core point, so its neighbors join the cluster
        if clusters[j] <= 0:  # unassigned (0) or previously marked noise (-1)
            clusters[j] = C

# Visualize the result using the first two features; noise is drawn in black
def plot_result(X, clusters):
    colors = ['r', 'g', 'b', 'y', 'c', 'm']
    for i in range(len(clusters)):
        if clusters[i] == -1:
            plt.scatter(X[i][0], X[i][1], color='k', marker='.')
        else:
            plt.scatter(X[i][0], X[i][1], color=colors[int(clusters[i]) % len(colors)], marker='.')
    plt.show()

# Entry point
if __name__ == '__main__':
    X = load_dataset()
    eps = 0.5
    min_pts = 5
    clusters = dbscan(X, eps, min_pts)
    plot_result(X, clusters)
```
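Here, `load_dataset()` loads the Iris dataset, `dist()` computes the Euclidean distance between two samples, `get_neighbors()` returns the indices of the points within `eps` of a given point, `expand_cluster()` grows a cluster from a core point, `dbscan()` is the main DBSCAN routine, and `plot_result()` visualizes the clustering result.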
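The neighbor lookup is the step that dominates the running time, because it calls `dist()` once per sample. As a minimal sketch, the same query could be written with NumPy broadcasting. The name `get_neighbors_vectorized` and the toy array `X_demo` are hypothetical and only serve this illustration; they are not part of the code above.

```python
import numpy as np

def get_neighbors_vectorized(X, i, eps):
    """Return the indices of all points within eps of X[i], including i itself."""
    # Broadcasting subtracts X[i] from every row; the norm over axis 1
    # then gives the Euclidean distance from X[i] to each sample.
    distances = np.linalg.norm(X - X[i], axis=1)
    return np.flatnonzero(distances <= eps)

# Toy data (an assumption for this example): three nearby points and one far away.
X_demo = np.array([[0.0, 0.0], [0.3, 0.0], [0.0, 0.4], [2.0, 2.0]])
neighbors_of_first = get_neighbors_vectorized(X_demo, 0, eps=0.5)
print(neighbors_of_first)  # [0 1 2]
```

The distances from the first point are 0, 0.3, 0.4 and about 2.83, so only the first three indices fall within `eps=0.5`.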
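In the main block, we first load the Iris dataset and then set the two parameters `eps` and `min_pts`. `eps` is the neighborhood radius and `min_pts` is the minimum number of points (counting the point itself) that must lie within that radius for a point to be a core point. A point that is not a core point but lies within `eps` of one is a border point, and everything else is noise. Finally, we call `dbscan()` to run the clustering and `plot_result()` to visualize the result.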
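To make the roles of `eps` and `min_pts` concrete, here is a small sketch that labels each point as core, border, or noise without forming clusters. The function `classify_points`, the toy array `X_demo`, and the parameter values are assumptions chosen for this illustration, not values tuned for the Iris data.

```python
import numpy as np

def classify_points(X, eps, min_pts):
    """Label each point 'core', 'border', or 'noise' for the given eps and min_pts."""
    # Pairwise Euclidean distances via broadcasting (fine for small datasets).
    diffs = X[:, None, :] - X[None, :, :]
    distances = np.sqrt((diffs ** 2).sum(axis=2))
    within_eps = distances <= eps
    # Core: the eps-neighborhood, the point itself included, holds >= min_pts points.
    is_core = within_eps.sum(axis=1) >= min_pts
    # Border: not core, but at least one core point lies within eps.
    is_border = ~is_core & (within_eps & is_core[None, :]).any(axis=1)
    return np.where(is_core, "core", np.where(is_border, "border", "noise"))

# Toy data: a tight square of four points, one point at its edge, one outlier.
X_demo = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1],
                   [0.5, 0.1], [3.0, 3.0]])
point_types = classify_points(X_demo, eps=0.45, min_pts=4)
print(point_types.tolist())
# ['core', 'core', 'core', 'core', 'border', 'noise']
```

The fifth point has only three points in its neighborhood (itself and two corners of the square), so it is not a core point, but it is within `eps` of core points and is therefore a border point. The outlier has no neighbors besides itself and is labeled noise.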