混合高斯贝叶斯模型代码
时间: 2023-08-24 22:08:22 浏览: 165
混合高斯贝叶斯模型的代码如下所示:
```python
import numpy as np
from scipy.stats import multivariate_normal
class GaussianMixtureBayes:
def __init__(self, k, max_iter=100, eps=1e-6):
self.k = k
self.max_iter = max_iter
self.eps = eps
def fit(self, X):
n, d = X.shape
# 初始化每个高斯分布的均值、协方差矩阵和权重
means = np.random.randn(self.k, d)
covs = np.array([np.eye(d)] * self.k)
weights = np.ones(self.k) / self.k
for _ in range(self.max_iter):
# E 步,计算每个样本属于每个高斯分布的概率
probs = np.zeros((n, self.k))
for i in range(self.k):
probs[:, i] = weights[i] * multivariate_normal.pdf(X, means[i], covs[i])
probs /= probs.sum(axis=1, keepdims=True)
# M 步,更新每个高斯分布的均值、协方差矩阵和权重
for i in range(self.k):
weight_sum = probs[:, i].sum()
means[i] = (probs[:, i] @ X) / weight_sum
covs[i] = ((probs[:, i] * (X - means[i]).T) @ (X - means[i])) / weight_sum
weights[i] = weight_sum / n
# 计算对数似然函数值
log_likelihood = np.log(probs @ weights).sum()
if np.abs(log_likelihood - self.log_likelihood) < self.eps:
break
self.log_likelihood = log_likelihood
self.means = means
self.covs = covs
self.weights = weights
def predict(self, X):
probs = np.zeros((X.shape[0], self.k))
for i in range(self.k):
probs[:, i] = self.weights[i] * multivariate_normal.pdf(X, self.means[i], self.covs[i])
return np.argmax(probs, axis=1)
```
在这里,我们使用 `scipy.stats` 中的 `multivariate_normal` 函数来计算多元高斯分布的概率密度函数。`fit` 方法用来训练模型,`predict` 方法用来预测新样本的类别。注意,这个代码仅适用于连续型变量。如果数据集同时包含离散和连续型变量,需要对每个分量采用适当的概率分布进行建模。
阅读全文