scipy.cluster.kmeans聚类,再计算每一类的重心
时间: 2024-01-12 21:03:55 浏览: 25
scipy.cluster.kmeans函数可以用于进行聚类分析,它可以将一组数据分成k个类,每个类中的数据尽可能相似。
例如,可以使用以下代码将数据分成3个类:
```python
import numpy as np
from scipy.cluster.vq import kmeans
# 生成数据
data = np.random.rand(100, 2)
# 聚类,分成3个类
centroids, labels = kmeans(data, 3)
# 打印每个数据点所属的类别
print(labels)
```
在得到每个数据点所属的类别之后,可以使用以下代码计算每个类别的重心:
```python
# 计算每个类别的重心
for i in range(3):
cluster = data[labels == i]
centroid = np.mean(cluster, axis=0)
print(f"Cluster {i} centroid: {centroid}")
```
这段代码中,我们首先使用`labels == i`选出第i个类别的数据点,然后使用`np.mean`函数计算这些数据点的平均值,即为该类别的重心。
相关问题
# 聚类数据 import numpy as np pos = np.arange(0, 6, 5) # 两个中心点在x=0,x=5 size = 50 # 生成两个簇 x = np.array([ np.random.normal(i+np.random.random(), 1, size) for i in pos ]).flatten() y = np.array([ np.random.normal(i+np.random.random(), 1, size) for i in pos ]).flatten() print(x.shape, y.shape)编写Python程序,实现以下要求: 画出上面数据的散点图 采用scipy.cluster.vq.kmeans()聚类,并在散点图中画出聚类中心点
以下是实现要求的Python代码:
```python
import numpy as np
from scipy.cluster.vq import kmeans
import matplotlib.pyplot as plt
pos = np.arange(0, 6, 5) # 两个中心点在x=0,x=5
size = 50 # 生成两个簇
x = np.array([np.random.normal(i+np.random.random(), 1, size) for i in pos]).flatten()
y = np.array([np.random.normal(i+np.random.random(), 1, size) for i in pos]).flatten()
# 画出散点图
plt.scatter(x, y)
# 聚类
data = np.array(list(zip(x, y)))
centroids, _ = kmeans(data, 2)
# 画出聚类中心点
plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=200, linewidths=3, color='r')
plt.show()
```
运行程序后,会弹出一个散点图窗口,其中红色叉号表示聚类中心点。
kmeans聚类算法树状图python实现
要使用K-means算法实现聚类树状图,可以按照以下步骤进行:
1. 导入所需的库:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from scipy.cluster.hierarchy import dendrogram, linkage
```
2. 准备数据集:
```python
X = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]])
```
3. 使用K-means算法进行聚类:
```python
kmeans = KMeans(n_clusters=2)
means.fit(X)
```
4. 获取聚类标签:
```python
labels = kmeans.labels_
```
5. 使用层次聚类算法生成树状图:
```python
Z = linkage(X, method='single')
dn = dendrogram(Z)
plt.show()
```
完整代码如下:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from scipy.cluster.hierarchy import dendrogram, linkage
X = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]])
kmeans = KMeans(n_clusters=2)
kmeans.fit(X)
labels = kmeans.labels_
Z = linkage(X, method='single')
dn = dendrogram(Z)
plt.show()
```