kmeans代码python
时间: 2023-10-08 22:14:07 浏览: 68
大数据作业:基于 Python 实现 KMeans,对球队强弱进行聚类(含源代码、文档说明和数据)
5星 · 资源好评率100%
以下是一个简单的 KMeans 聚类算法的 Python 代码实现:
```python
import numpy as np
import matplotlib.pyplot as plt
class KMeans:
    """K-means clustering using Euclidean distance.

    Centroids are seeded from randomly chosen rows of the training data and
    then refined by alternating an assignment step and a mean-update step
    until they stop moving or ``max_iter`` iterations have run.

    Attributes set by ``fit``:
        n_samples, n_features: shape of the training data.
        centroids: array of shape ``(k, n_features)``.
        initial_centroids: copy of the centroids before any update.
        cluster_assignments: float array of length ``n_samples`` holding the
            centroid index assigned to each training row.
    """

    def __init__(self, k=2, max_iter=100, plot_steps=False):
        """Store the hyper-parameters.

        Args:
            k: number of clusters.
            max_iter: upper bound on assignment/update iterations.
            plot_steps: if true, ``fit`` calls ``plot`` after every
                assignment step.
        """
        self.k = k
        self.max_iter = max_iter
        self.plot_steps = plot_steps

    def fit(self, X):
        """Cluster ``X`` and return one label per row.

        Args:
            X: 2-D array-like of shape ``(n_samples, n_features)``.

        Returns:
            Float array of length ``n_samples`` with the index of the
            nearest final centroid for each row.

        Raises:
            ValueError: if ``X`` is not 2-D or ``k`` is not in
                ``[1, n_samples]``.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        self.n_samples, self.n_features = X.shape
        if not 1 <= self.k <= self.n_samples:
            raise ValueError("k must be between 1 and the number of samples")

        # Seed from distinct data rows: standard-normal seeds ignore where the
        # data actually lies and tend to leave clusters empty from the start.
        seed_rows = np.random.choice(self.n_samples, size=self.k, replace=False)
        self.centroids = X[seed_rows].copy()
        self.initial_centroids = self.centroids.copy()
        self.cluster_assignments = np.zeros(self.n_samples)

        for _ in range(self.max_iter):
            self._assign_clusters(X)
            if self.plot_steps:
                self.plot(X)
            old_centroids = self.centroids.copy()
            self._update_centroids(X)
            if np.allclose(self.centroids, old_centroids):
                break
        return self.get_cluster_labels(X)

    def _nearest_centroid(self, X):
        """Return the index of the closest centroid for every row of ``X``."""
        X = np.asarray(X, dtype=float)
        # Broadcast to (n_rows, k, n_features) so all distances come from one
        # expression instead of a Python loop over rows.
        deltas = X[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]
        distances = np.sqrt(np.sum(deltas ** 2, axis=2))
        return np.argmin(distances, axis=1)

    def _assign_clusters(self, X):
        """Record the nearest centroid of each row in ``cluster_assignments``."""
        # Kept as float to match the dtype this attribute has always had.
        self.cluster_assignments = self._nearest_centroid(X).astype(float)

    def _update_centroids(self, X):
        """Move each centroid to the mean of the rows assigned to it."""
        X = np.asarray(X, dtype=float)
        for i in range(self.k):
            members = X[self.cluster_assignments == i]
            # An empty cluster keeps its previous centroid; averaging zero
            # rows would produce NaN and corrupt all later distances.
            if len(members):
                self.centroids[i] = members.mean(axis=0)

    def get_cluster_labels(self, X):
        """Return the nearest-centroid index for each row of ``X``.

        ``X`` may have any number of rows; it does not have to be the
        training data. The result is a float array with one entry per row.
        """
        return self._nearest_centroid(X).astype(float)

    def plot(self, X):
        """Scatter the first two features of ``X`` by current assignment.

        Uses ``cluster_assignments``, so ``X`` must be the data those
        assignments were computed for. Centroids are drawn as star markers.
        """
        X = np.asarray(X)
        fig, ax = plt.subplots(figsize=(12, 8))
        for i in range(self.k):
            points = X[self.cluster_assignments == i]
            ax.scatter(points[:, 0], points[:, 1], s=50)
        ax.scatter(self.centroids[:, 0], self.centroids[:, 1], marker='*', s=200, c='#050505')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        plt.show()
```
使用该代码进行 KMeans 聚类的例子:
```python
from sklearn.datasets import make_blobs
# Generate a reproducible toy data set: 500 samples around 3 centres.
X, y = make_blobs(centers=3, n_samples=500, random_state=42)

# Fit the model; plot_steps=True makes fit() draw a figure on every iteration.
kmeans = KMeans(k=3, max_iter=150, plot_steps=True)
y_pred = kmeans.fit(X)

# Draw the final result with one scatter call per predicted cluster.
clusters = np.unique(y_pred)
for cluster in clusters:
    in_cluster = y_pred == cluster
    plt.scatter(X[in_cluster, 0], X[in_cluster, 1])
plt.show()
```
运行结果会显示如下聚类图像:
![KMeans聚类图像](https://img-blog.csdnimg.cn/20210913150054169.png)
阅读全文