def fit(self, data): """process of gratitude dereasing of params in GMM """ self.init_params(data) for e in range(self.n_epochs): ## e-step: 计算后验 posterior = self.e_step() ## m-step: 计算梯度,并更新参数 for cls, (grad_class, grad_mu, grad_sigma) in \ zip(range(self.n_class), self.m_step(posterior)): self.class_prob[cls] += 1e-3 *grad_class self.mus[cls] += 1e-3 * grad_mu self.vars[cls] += 1e-3 * grad_sigma self.class_prob /= self.class_prob.sum() print (e)
时间: 2024-02-10 12:32:41 浏览: 91
这是一个 GMM(高斯混合模型)的训练过程,其中包含了 E 步和 M 步。E 步计算每个样本属于每个混合成分的后验概率,M 步计算梯度并更新模型参数。具体来说,这个 fit 函数接受一个数据集作为输入,然后初始化 GMM 的参数。接着进行 n_epochs 轮迭代,每轮迭代分为 E 步和 M 步。在 E 步中,计算每个样本属于每个混合成分的后验概率。在 M 步中,计算每个混合成分的梯度,并更新混合成分的权重、均值和方差。这里的更新采用了梯度下降法,学习率为 1e-3。最后,归一化混合成分的权重,并输出当前迭代轮数。
相关问题
ef m_step(self, posterior): """Maximization step in EM algorithm, use last time posterior p(z|x) to calculate params gratitude. Args: posterior: [n_sample, n_class] p(z=i | x_i, \theta_t) Return: Each class param's gratitude in current time step grad_class_prob: scatter of class j grad_mus: [,dim] jth class mus grad_sigma: [, dim, dim] jth class sigma """ for cls in range(self.n_class): ## class_prob gratitudes grad_class_prob = posterior[:, cls].sum() / self.n_sample ## mu_j <- (\sum_i p(z_j|x_i) * x_i) / sum_i p(z_j |x_i) grad_mus = np.zeros(self.n_dim) for ind in range(self.n_sample): grad_mus += posterior[ind, cls] * self.data[ind, :] grad_mus /= posterior[:, cls].sum() ## sigma_j <- (\sum_i p(z_j|x_i) * (x_i - \mu_j)^2) / sum_i p(z_j |x_i) grad_sigma = np.zeros((self.n_dim, self.n_dim)) for ind in range(self.n_sample): grad_sigma += posterior[ind, cls] * \ np.dot((self.data[ind, :] - self.mus[cls]), self.data[ind, :] - self.mus[cls].T) grad_sigma /= posterior[:, cls].sum() yield grad_class_prob, grad_mus, grad_sigma 这段代码的作用
这段代码是高斯混合模型的EM算法中的M步骤,即最大化步骤,用上一步得到的后验概率p(z|x)来计算模型的参数(即类别概率、均值和协方差矩阵)的梯度。
具体来说,这段代码中的posterior是一个n_samples x n_components的矩阵,表示每个样本属于每个聚类的后验概率。grad_class_prob、grad_mus和grad_sigma分别表示类别概率、均值和协方差矩阵的梯度,通过yield关键字返回结果。
代码的第一个循环计算类别概率的梯度,即将某个聚类中所有样本的后验概率相加后除以样本总数,得到该聚类的类别概率的梯度。
代码的第二个循环计算均值的梯度,即将某个聚类中所有样本的后验概率乘以该样本的特征向量相加,然后除以该聚类所有样本的后验概率之和,得到该聚类的均值的梯度。
代码的第三个循环计算协方差矩阵的梯度,即将某个聚类中所有样本的后验概率乘以该样本与该聚类均值之差的矩阵相乘,然后加权求和,最后除以该聚类所有样本的后验概率之和,得到该聚类的协方差矩阵的梯度。
最后,通过yield返回三个梯度,用于更新EM算法中的参数。
阅读全文