卷积神经网络 tensorflow生成混淆矩阵ROC曲线的库
时间: 2023-11-13 18:02:10 浏览: 150
通过matlab绘制数据的ROC曲线,并输出TN,FP,FN,TP混淆矩阵,含仿真操作录像
5星 · 资源好评率100%
在 TensorFlow 中,可以使用 tf.math.confusion_matrix() 函数来生成混淆矩阵。ROC曲线的生成可以使用 sklearn.metrics中的roc_curve()函数。
下面是一个简单的示例代码:
```python
import tensorflow as tf
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
# 加载数据集并进行预处理
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = tf.keras.applications.resnet50.preprocess_input(x_train)
x_test = tf.keras.applications.resnet50.preprocess_input(x_test)
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
# 构建模型
model = tf.keras.applications.ResNet50(weights=None, classes=10)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))
# 生成混淆矩阵
y_pred = model.predict(x_test)
cm = tf.math.confusion_matrix(tf.argmax(y_test, axis=1), tf.argmax(y_pred, axis=1))
# 绘制ROC曲线
fpr, tpr, _ = roc_curve(y_test.ravel(), y_pred.ravel())
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.show()
```
这个示例代码中使用了 ResNet50 模型来进行图像分类,并使用了 CIFAR10 数据集。训练完模型后,使用 tf.math.confusion_matrix() 函数来生成混淆矩阵,然后使用 sklearn.metrics 中的 roc_curve() 函数来生成 ROC 曲线。最后使用 matplotlib 库来绘制 ROC 曲线。
阅读全文