NMI computation function
Date: 2023-11-11 17:16:36
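NMI (Normalized Mutual Information) is a metric for evaluating clustering results: it measures how closely the cluster assignments agree with the ground-truth labels, independent of how the clusters happen to be numbered. A function that computes NMI is shown below: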
```python
import numpy as np


def nmi(labels_true, labels_pred):
    """
    Compute the NMI (Normalized Mutual Information) score.

    :param labels_true: ground-truth labels
    :param labels_pred: predicted (cluster) labels
    :return: NMI value
    """
    labels_true = np.asarray(labels_true)
    labels_pred = np.asarray(labels_pred)
    n = len(labels_true)
    classes = np.unique(labels_true)  # distinct ground-truth classes
    clusters = np.unique(labels_pred)  # distinct predicted clusters
    class_size = np.zeros(len(classes))  # number of samples in each class
    cluster_size = np.zeros(len(clusters))  # number of samples in each cluster
    class_cluster = np.zeros((len(classes), len(clusters)))  # class/cluster overlap counts
    for i in range(n):
        class_idx = np.argwhere(classes == labels_true[i])[0][0]  # index of the true class
        cluster_idx = np.argwhere(clusters == labels_pred[i])[0][0]  # index of the predicted cluster
        class_size[class_idx] += 1
        cluster_size[cluster_idx] += 1
        class_cluster[class_idx, cluster_idx] += 1
    # Mutual information; empty cells contribute nothing and are skipped.
    mi = 0.0
    for i in range(len(classes)):
        for j in range(len(clusters)):
            if class_cluster[i, j] != 0:
                mi += (class_cluster[i, j] / n) * np.log(
                    n * class_cluster[i, j] / (class_size[i] * cluster_size[j])
                )
    # Entropies; every class and cluster holds at least one sample, so log is safe.
    h_true = -np.sum((class_size / n) * np.log(class_size / n))
    h_pred = -np.sum((cluster_size / n) * np.log(cluster_size / n))
    # A labeling with a single group has zero entropy; avoid dividing by zero.
    if h_true == 0 or h_pred == 0:
        return 1.0 if h_true == h_pred else 0.0
    return mi / np.sqrt(h_true * h_pred)
```
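Here, `labels_true` holds the ground-truth labels and `labels_pred` holds the labels predicted by the clustering algorithm. `classes` and `clusters` are the distinct values that occur in the true and predicted labels respectively (their lengths give the number of classes and clusters). `class_size` and `cluster_size` store the number of samples in each class and each cluster, and `class_cluster` stores the number of samples shared by each class/cluster pair. The function first accumulates the mutual information `mi` from these counts, then computes the entropies `h_true` and `h_pred` of the two labelings, and finally returns `mi / sqrt(h_true * h_pred)`, that is, the mutual information normalized by the geometric mean of the two entropies.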
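As a cross-check of the steps described above, here is a minimal sketch of the same calculation using only the Python standard library. The name `nmi_sketch` is hypothetical. Its handling of zero entropy is an assumption rather than something stated in the original: it returns 1.0 when both labelings consist of a single group and 0.0 when only one of them does.

```python
import math
from collections import Counter


def nmi_sketch(labels_true, labels_pred):
    """NMI normalized by sqrt(H_true * H_pred), standard library only."""
    n = len(labels_true)
    if n == 0 or n != len(labels_pred):
        raise ValueError("label sequences must be non-empty and of equal length")

    class_size = Counter(labels_true)                # samples per class
    cluster_size = Counter(labels_pred)              # samples per cluster
    joint = Counter(zip(labels_true, labels_pred))   # samples per (class, cluster) pair

    # Mutual information: only pairs that actually occur are in `joint`.
    mi = sum(
        (n_ij / n) * math.log(n * n_ij / (class_size[c] * cluster_size[k]))
        for (c, k), n_ij in joint.items()
    )

    # Entropy of each labeling.
    h_true = sum(-(s / n) * math.log(s / n) for s in class_size.values())
    h_pred = sum(-(s / n) * math.log(s / n) for s in cluster_size.values())

    # Assumed convention for degenerate labelings (not from the original text).
    if h_true == 0 or h_pred == 0:
        return 1.0 if h_true == h_pred else 0.0
    return mi / math.sqrt(h_true * h_pred)


if __name__ == "__main__":
    # Same partition with swapped cluster ids.
    print(nmi_sketch([0, 0, 1, 1], [1, 1, 0, 0]))
    # Clusters that carry no information about the classes.
    print(nmi_sketch([0, 0, 1, 1], [0, 1, 0, 1]))
```

The first call should print a value equal to 1 up to floating-point rounding, because relabeling the clusters does not change the partition. The second call prints 0.0: every class/cluster pair occurs exactly as often as independence would predict, so each term in the mutual information is log(1) = 0.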