import pandas as pd import numpy as np import matplotlib.pyplot as plt import math data = np.load('dataset_em.npy') mu0 = np.array([np.mean(data) - 1, np.mean(data) + 1]) sg0 = np.array([1.0, 1.0]) w0 = [0.5, 0.5] ary_mu = [mu0] ary_sg = [sg0] ary_w = [w0] x = np.arange(-10, 15, 0.1) def pdf_gauss(x, mu, sg): c = 1 / np.sqrt(2 * np.pi) p = c / sg * np.exp(-0.5 * ((x - mu) / sg) ** 2) return p def gauss_mix_pdf(w, x, mu, sg): p = w[0] * pdf_gauss(x, mu[0], sg[0]) + w[1] * pdf_gauss(x, mu[1], sg[1]) return p for j in range(0, 100): mu = ary_mu[j] sg = ary_sg[j] w = ary_w[j] gama = [] for i in (0, 1): b = w[i] * pdf_gauss(data, mu[i], sg[i]) c = gauss_mix_pdf(w, data, mu, sg) gama.append(b / c) gama_sum0 = np.sum(gama[0]) gama_sum1 = np.sum(gama[1]) new_mu = [np.sum(data * gama[0]) / gama_sum0, np.sum(data * gama[1]) / gama_sum1] new_sg = [np.sqrt(np.sum(gama[0] * np.power(data - new_mu[0], 2)) / gama_sum0), np.sqrt(np.sum(gama[1] * np.power(data - new_mu[1], 2)) / gama_sum1)] # 第二个数太大了 new_w = [gama_sum0 / data.shape[0], gama_sum1 / data.shape[0]] ary_mu.extend([new_mu]) ary_sg.extend([new_sg]) ary_w.extend([new_w]) del gama y = gauss_mix_pdf(ary_w[100], x, ary_mu[100], ary_sg[100]) plt.figure() plt.hist(data, bins=100, facecolor="blue", edgecolor="black", alpha=0.7, density=True) plt.plot(x, y) plt.show() 请逐句详细解释这段代码
时间: 2023-11-27 12:51:27 浏览: 111
这段代码实现了一个高斯混合模型(Gaussian Mixture Model, GMM)的 EM 算法,用来对给定的数据进行聚类。下面是每一行代码的详细解释:
```python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
```
导入所需要的库。
```python
data = np.load('dataset_em.npy')
```
从文件中加载数据。
```python
mu0 = np.array([np.mean(data) - 1, np.mean(data) + 1])
sg0 = np.array([1.0, 1.0])
w0 = [0.5, 0.5]
```
初始化高斯混合模型的均值(mu)、标准差(sg)和权重(w)。
```python
ary_mu = [mu0]
ary_sg = [sg0]
ary_w = [w0]
```
将初始值存储在数组中,以便后续的 EM 算法迭代更新。
```python
x = np.arange(-10, 15, 0.1)
```
生成 X 轴的数据,用于画图。
```python
def pdf_gauss(x, mu, sg):
c = 1 / np.sqrt(2 * np.pi)
p = c / sg * np.exp(-0.5 * ((x - mu) / sg) ** 2)
return p
```
定义高斯分布的概率密度函数。
```python
def gauss_mix_pdf(w, x, mu, sg):
p = w[0] * pdf_gauss(x, mu[0], sg[0]) + w[1] * pdf_gauss(x, mu[1], sg[1])
return p
```
定义高斯混合分布的概率密度函数,其中 w 为权重,x、mu、sg 分别为 X 轴、均值和标准差。
```python
for j in range(0, 100):
mu = ary_mu[j]
sg = ary_sg[j]
w = ary_w[j]
gama = []
for i in (0, 1):
b = w[i] * pdf_gauss(data, mu[i], sg[i])
c = gauss_mix_pdf(w, data, mu, sg)
gama.append(b / c)
gama_sum0 = np.sum(gama[0])
gama_sum1 = np.sum(gama[1])
new_mu = [np.sum(data * gama[0]) / gama_sum0, np.sum(data * gama[1]) / gama_sum1]
new_sg = [np.sqrt(np.sum(gama[0] * np.power(data - new_mu[0], 2)) / gama_sum0),
np.sqrt(np.sum(gama[1] * np.power(data - new_mu[1], 2)) / gama_sum1)]
new_w = [gama_sum0 / data.shape[0], gama_sum1 / data.shape[0]]
ary_mu.extend([new_mu])
ary_sg.extend([new_sg])
ary_w.extend([new_w])
del gama
```
进行 EM 算法的迭代更新,其中 j 为迭代次数,data 为输入数据,mu、sg、w 分别为均值、标准差和权重。gama 为 E 步中计算的后验概率,new_mu、new_sg、new_w 分别为 M 步中更新后的均值、标准差和权重。最后将更新后的均值、标准差和权重存储在数组中。
```python
y = gauss_mix_pdf(ary_w[100], x, ary_mu[100], ary_sg[100])
plt.figure()
plt.hist(data, bins=100, facecolor="blue", edgecolor="black", alpha=0.7, density=True)
plt.plot(x, y)
plt.show()
```
画出最终的 GMM 模型和数据的分布直方图。其中 y 为高斯混合分布的概率密度函数,ary_w[100]、ary_mu[100]、ary_sg[100] 分别为最终的权重、均值和标准差。plt.hist() 画出数据的分布直方图,plt.plot() 画出 GMM 模型的曲线。
阅读全文