python多分类混淆矩阵代码_keras中用于多类分类的混淆矩阵生成
时间: 2023-07-21 17:11:48 浏览: 219
plot_confusion_keras_混淆矩阵_plotconfusion_
5星 · 资源好评率100%
以下是使用 Keras 生成多分类混淆矩阵的示例代码:
```python
import numpy as np
from sklearn.metrics import confusion_matrix
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
# 生成随机数据,共有4个类别
X_train = np.random.rand(1000, 10)
y_train = np.random.randint(0, 4, 1000)
X_test = np.random.rand(200, 10)
y_test = np.random.randint(0, 4, 200)
# 将标签转换为one-hot编码
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
# 定义模型
model = Sequential()
model.add(Dense(32, input_dim=10, activation='relu'))
model.add(Dense(4, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# 训练模型
model.fit(X_train, y_train, epochs=10, batch_size=32, verbose=0)
# 预测测试集结果
y_pred = model.predict_classes(X_test)
# 生成混淆矩阵
cm = confusion_matrix(np.argmax(y_test, axis=1), y_pred)
print(cm)
```
解释:
1. 生成了一个随机数据集,共有4个类别。
2. 将标签转换为one-hot编码。
3. 定义了一个有2个全连接层的神经网络,输出层使用softmax激活函数。
4. 训练模型。
5. 预测测试集结果并生成混淆矩阵。其中,np.argmax(y_test, axis=1)将one-hot编码的标签转换为原始标签。
阅读全文