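Write Python code that uses the LVQ algorithm for customer clustering, and visualize the results
Time: 2024-03-24 16:41:26
Sure. Below is Python code that implements the LVQ (Learning Vector Quantization) algorithm for customer clustering, followed by a visual analysis with Matplotlib: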
```python
import numpy as np
import matplotlib.pyplot as plt


class LVQ:
    def __init__(self, k, alpha, epochs):
        self.k = k              # total number of prototypes (clusters), spread over the classes
        self.alpha = alpha      # initial learning rate
        self.epochs = epochs    # number of training epochs
        self.weights = None     # prototype (weight) vectors
        self.labels = None      # class label of each prototype

    def train(self, X, y):
        # give the prototypes class labels round-robin so every class gets at least one
        # (requires k >= number of classes)
        classes = np.unique(y)
        self.labels = classes[np.arange(self.k) % len(classes)]
        # initialize each prototype with a distinct random sample of its own class,
        # so prototypes start inside the data rather than in the [0, 1) square
        self.weights = np.empty((self.k, X.shape[1]))
        for c in classes:
            slots = np.flatnonzero(self.labels == c)
            picks = np.random.choice(np.flatnonzero(y == c), size=len(slots), replace=False)
            self.weights[slots] = X[picks]
        # train for the specified number of epochs
        for epoch in range(self.epochs):
            # linearly decay the learning rate
            alpha = self.alpha * (1.0 - epoch / float(self.epochs))
            # visit the data points in a random order each epoch
            for i in np.random.permutation(X.shape[0]):
                # find the closest prototype
                distances = np.linalg.norm(self.weights - X[i], axis=1)
                closest = np.argmin(distances)
                # pull it toward the sample if the labels match, push it away otherwise
                if self.labels[closest] == y[i]:
                    self.weights[closest] += alpha * (X[i] - self.weights[closest])
                else:
                    self.weights[closest] -= alpha * (X[i] - self.weights[closest])
        return self

    def predict(self, X):
        # each sample takes the label of its nearest prototype
        distances = np.linalg.norm(X[:, None, :] - self.weights[None, :, :], axis=2)
        return self.labels[np.argmin(distances, axis=1)]


def plot_clusters(model, X, y):
    # plot data points colored by predicted label
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis')
    # plot the prototypes (cluster centers)
    plt.scatter(model.weights[:, 0], model.weights[:, 1], marker='x', color='red', s=100, linewidth=2)
    plt.title('Customer Clustering with LVQ')
    plt.xlabel('Age')
    plt.ylabel('Income')
    plt.show()


# generate sample data (np.random.seed also makes the training reproducible)
np.random.seed(0)
X = np.random.randn(200, 2)
y = np.random.randint(0, 2, 200)
# create LVQ model
lvq = LVQ(k=2, alpha=0.1, epochs=100)
# train LVQ model
lvq.train(X, y)
# predict cluster labels
y_pred = lvq.predict(X)
# plot clusters
plot_clusters(lvq, X, y_pred)
```
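In the code above, we first generate a dataset of 200 samples, each with 2 features standing in for age and income (synthetic standard-normal values, not real ages or incomes), and randomly assign each sample to one of 2 classes. We then create an LVQ model and train it on these labeled samples. Note that LVQ is a supervised method: it needs a class label for every training sample, and each prototype represents one class, so the "clusters" here are the regions of feature space nearest to each class's prototypes. Finally, we predict the cluster of each sample and visualize the result with Matplotlib. Because the labels are random and unrelated to the features, this example only demonstrates the mechanics; with real customer data you would supply meaningful segment labels.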
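The heart of training is the LVQ1 update: the prototype nearest to a sample moves toward it by a fraction `alpha` of the gap when their labels agree, and away from it when they differ. The sketch below shows a single step; the function name `lvq1_step` and the toy numbers are illustrative only.

```python
import numpy as np


def lvq1_step(weights, labels, x, y, alpha):
    """Apply one LVQ1 update for sample x with label y; return an updated copy and the winner index."""
    weights = weights.copy()
    # winner = prototype with the smallest Euclidean distance to x
    winner = int(np.argmin(np.linalg.norm(weights - x, axis=1)))
    # attract on a label match, repel on a mismatch
    sign = 1.0 if labels[winner] == y else -1.0
    weights[winner] += sign * alpha * (x - weights[winner])
    return weights, winner


weights = np.array([[0.0, 0.0], [1.0, 1.0]])
labels = np.array([0, 1])
x = np.array([0.2, 0.4])

# x belongs to class 1, but prototype 0 (class 0) is nearer, so it is pushed away
new_w, winner = lvq1_step(weights, labels, x, y=1, alpha=0.5)
print(winner, new_w[0])  # 0 [-0.1 -0.2]
```

With `y=0` instead, the same prototype would be pulled toward `x`, to `[0.1, 0.2]`.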
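After running the code, you will see a scatter plot of the two classes, in which the red crosses mark the prototype vectors (the cluster centers). This example uses 2 clusters (prototypes), but you can choose a different number to suit your dataset and requirements.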
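One way to choose the number of prototypes is to compare held-out accuracy for a few candidate values. The sketch below assumes hypothetical, synthetically generated age/income data with two customer segments, places `k_per_class` prototypes in each class, and z-scores the features using training-split statistics first, because raw age and income live on very different scales and would otherwise distort Euclidean distances. The function names `train_lvq` and `predict` are illustrative.

```python
import numpy as np


def train_lvq(X, y, k_per_class, alpha=0.1, epochs=30, seed=0):
    """LVQ1 with k_per_class prototypes per class, each initialized from a random sample of its class."""
    rng = np.random.default_rng(seed)
    classes = np.unique(y)
    weights = np.vstack([X[rng.choice(np.flatnonzero(y == c), k_per_class, replace=False)]
                         for c in classes]).astype(float)
    labels = np.repeat(classes, k_per_class)  # matches the row order of weights
    for epoch in range(epochs):
        a = alpha * (1.0 - epoch / epochs)  # linearly decaying learning rate
        for i in rng.permutation(len(X)):
            j = np.argmin(np.linalg.norm(weights - X[i], axis=1))  # nearest prototype
            sign = 1.0 if labels[j] == y[i] else -1.0  # attract on match, repel on mismatch
            weights[j] += sign * a * (X[i] - weights[j])
    return weights, labels


def predict(weights, labels, X):
    # label of the nearest prototype for every row of X
    d = np.linalg.norm(X[:, None, :] - weights[None, :, :], axis=2)
    return labels[np.argmin(d, axis=1)]


# Hypothetical customer data (synthetic): columns are age in years and income in thousands
rng = np.random.default_rng(42)
X_raw = np.vstack([rng.normal([30, 40], [5, 8], size=(150, 2)),
                   rng.normal([55, 90], [6, 15], size=(150, 2))])
y = np.repeat([0, 1], 150)

# random train/test split
idx = rng.permutation(len(X_raw))
train, test = idx[:200], idx[200:]

# z-score with training statistics so age and income weigh comparably in the distance
mu, sigma = X_raw[train].mean(axis=0), X_raw[train].std(axis=0)
X = (X_raw - mu) / sigma

scores = {}
for k in (1, 2, 4):
    w, lab = train_lvq(X[train], y[train], k_per_class=k)
    scores[k] = float(np.mean(predict(w, lab, X[test]) == y[test]))
print(scores)  # held-out accuracy per number of prototypes per class
```

A reasonable rule of thumb is to pick the smallest value whose accuracy is close to the best: more prototypes can follow more irregular segment boundaries, but they cost more to train and can overfit small datasets.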