Write confusion matrix code in Python using TensorFlow
Date: 2024-02-01 10:02:35
The following code computes a confusion matrix with TensorFlow:
```python
import tensorflow as tf
def confusion_matrix(actuals, predictions, num_classes):
    """
    Computes the confusion matrix given actual and predicted labels.
    :param actuals: Actual labels in one-hot format.
    :param predictions: Predicted labels in one-hot format.
    :param num_classes: Number of classes in the classification problem.
    :return: Confusion matrix as a tensor.
    """
    # Convert one-hot encoded labels to class indices
    actuals = tf.argmax(actuals, axis=1)
    predictions = tf.argmax(predictions, axis=1)
    # Compute the confusion matrix (rows: actual class, columns: predicted class)
    conf_matrix = tf.math.confusion_matrix(actuals, predictions, num_classes=num_classes)
    return conf_matrix
```
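Because the function reduces each row to a class index with `tf.argmax(..., axis=1)`, it should give the same result whether the predictions are one-hot vectors or raw per-class scores such as softmax probabilities. The NumPy sketch below illustrates that step in isolation; the score values are made-up illustrative data, and `np.argmax` stands in for `tf.argmax`.

```python
import numpy as np

# Made-up per-class scores for 4 samples and 3 classes (e.g. softmax output).
scores = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.5, 0.4, 0.1],
])

# The same predictions expressed as one-hot vectors of the highest-scoring class.
one_hot = np.eye(3)[np.argmax(scores, axis=1)]

# argmax along axis 1 yields identical class indices for both representations.
print(np.argmax(scores, axis=1))   # [0 1 2 0]
print(np.argmax(one_hot, axis=1))  # [0 1 2 0]
```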
Usage example:
```python
# Generate some sample data
actuals = tf.constant([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
predictions = tf.constant([0, 1, 2, 3, 4, 1, 2, 3, 4, 0])
num_classes = 5
# Compute the confusion matrix
conf_matrix = confusion_matrix(tf.one_hot(actuals, num_classes), tf.one_hot(predictions, num_classes), num_classes)
# Print the confusion matrix
print(conf_matrix)
```
Output:
```
tf.Tensor(
[[1 1 0 0 0]
 [0 1 1 0 0]
 [0 0 1 1 0]
 [0 0 0 1 1]
 [1 0 0 0 1]], shape=(5, 5), dtype=int32)
```
The function returns a tensor of shape (num_classes, num_classes) in which element (i, j) is the number of samples whose actual label is i and whose predicted label is j.
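To make the (i, j) convention concrete, here is a minimal NumPy sketch that counts (actual, predicted) pairs directly on the example data from the usage example. The helper name `confusion_matrix_np` is hypothetical, and the flat-index plus `np.bincount` approach is just one way to count pairs, not necessarily how `tf.math.confusion_matrix` is implemented internally. It also shows two ways to read the result: row sums give the number of samples per actual class, and the diagonal (trace) counts correct predictions.

```python
import numpy as np

def confusion_matrix_np(actuals, predictions, num_classes):
    """Hypothetical helper: rows are actual classes, columns are predicted classes."""
    actuals = np.asarray(actuals)
    predictions = np.asarray(predictions)
    # Map each pair (i, j) to one flat cell index i * num_classes + j.
    flat = actuals * num_classes + predictions
    # Count how often each flat cell occurs, then reshape back to a square matrix.
    counts = np.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

actuals = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
predictions = [0, 1, 2, 3, 4, 1, 2, 3, 4, 0]
cm = confusion_matrix_np(actuals, predictions, 5)
print(cm)
print(cm.sum(axis=1))  # samples per actual class: [2 2 2 2 2]
print(np.trace(cm))    # correctly classified samples: 5
```

`np.bincount` only accepts non-negative integers, so this sketch assumes the labels are already class indices in the range 0 to num_classes - 1.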