用python实现使用k-means聚类算法对数据集进行聚类
时间: 2023-06-25 08:02:53 浏览: 201
python实现k-means聚类算法
好的,下面是使用Python实现k-means聚类算法对数据集进行聚类的代码:
首先,导入需要的库:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
```
接着,读取数据集:
```python
data = pd.read_csv('data.csv')
X = data.iloc[:, [0, 1]].values
```
其中,`data.csv`是包含两列数据的CSV文件,每一行代表一个数据点。
然后,使用Elbow方法确定最佳的聚类数:
```python
wcss = []
for i in range(1, 11):
kmeans = KMeans(n_clusters=i, init='k-means++', max_iter=300, n_init=10, random_state=0)
kmeans.fit(X)
wcss.append(kmeans.inertia_)
plt.plot(range(1, 11), wcss)
plt.title('Elbow method')
plt.xlabel('Number of clusters')
plt.ylabel('WCSS')
plt.show()
```
在上述代码中,我们对1到10个聚类数进行了循环,并计算了每个聚类数下的WCSS(Within-Cluster-Sum-of-Squares)。然后,我们绘制了聚类数与WCSS之间的关系图,通过观察图像,我们可以大致确定最佳的聚类数。
最后,使用确定的聚类数进行k-means聚类:
```python
kmeans = KMeans(n_clusters=3, init='k-means++', max_iter=300, n_init=10, random_state=0)
y_kmeans = kmeans.fit_predict(X)
```
在上述代码中,我们将聚类数设为3,并使用`fit_predict()`方法对数据进行聚类,并将聚类结果赋值给`y_kmeans`。
下面是完整的代码:
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
data = pd.read_csv('data.csv')
X = data.iloc[:, [0, 1]].values
wcss = []
for i in range(1, 11):
kmeans = KMeans(n_clusters=i, init='k-means++', max_iter=300, n_init=10, random_state=0)
kmeans.fit(X)
wcss.append(kmeans.inertia_)
plt.plot(range(1, 11), wcss)
plt.title('Elbow method')
plt.xlabel('Number of clusters')
plt.ylabel('WCSS')
plt.show()
kmeans = KMeans(n_clusters=3, init='k-means++', max_iter=300, n_init=10, random_state=0)
y_kmeans = kmeans.fit_predict(X)
plt.scatter(X[y_kmeans == 0, 0], X[y_kmeans == 0, 1], s=100, c='red', label='Cluster 1')
plt.scatter(X[y_kmeans == 1, 0], X[y_kmeans == 1, 1], s=100, c='blue', label='Cluster 2')
plt.scatter(X[y_kmeans == 2, 0], X[y_kmeans == 2, 1], s=100, c='green', label='Cluster 3')
plt.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], s=300, c='yellow', label='Centroids')
plt.title('Clusters')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.show()
```
其中,`data.csv`文件的内容如下:
```
X,Y
2,3
2,4
3,4
6,6
7,5
7,7
```
最后,我们还绘制了聚类结果的散点图。这里,我们使用不同的颜色表示不同的聚类,使用黄色的大点表示每个聚类的中心点。
阅读全文