针对 MNIST 数据集,利用 Python 实现基于全连接神经网络的数据分类分析, 并对数据以及结果进行分析及可视化。包括预测值与真实值的对比图,均方误差,R方指标
时间: 2024-12-13 11:27:22 浏览: 4
基于卷积神经网络实现MNIST手写数字数据集识别应用+GUI可视化源码+数据(课程设计).zip
5星 · 资源好评率100%
在Python中,我们可以使用深度学习库如TensorFlow或PyTorch来构建全连接神经网络(Dense Neural Network)并应用于MNIST手写数字识别任务。以下是简要步骤:
1. **数据加载与预处理**:
使用`tensorflow.keras.datasets.mnist`库加载MNIST数据集,分为训练集和测试集。首先,将图像数据归一化到0-1范围,并将标签转换为one-hot编码。
```python
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
```
2. **模型构建**:
创建一个包含几个隐藏层的全连接神经网络,加上一个输出层用于10个类别(每个代表0-9的数字)。
```python
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
```
3. **模型编译**:
编译模型,指定损失函数(如交叉熵),优化器(如Adam),以及评估指标(准确率和均方误差)。
```python
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
```
4. **模型训练**:
将训练数据送入模型进行训练,设置适当的批次大小和迭代次数。
```python
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)
```
5. **模型评估与结果分析**:
- 计算测试集上的预测值与真实值的对比图(混淆矩阵),可以使用`sklearn.metrics.confusion_matrix`或`tf.keras.metrics.confusion_matrix`。
- 计算均方误差(Mean Squared Error, MSE)和R方(R^2)指标,MSE通常用于回归问题,对于分类问题需要先转化为概率预测再计算。例如,使用`sklearn.metrics.mean_squared_error`和`sklearn.metrics.r2_score`。
```python
y_pred = model.predict_classes(x_test)
mse = tf.keras.losses.sparse_categorical_crossentropy(y_test, y_pred).numpy().mean()
r2 = sklearn.metrics.r2_score(tf.argmax(y_test, axis=1), y_pred)
```
6. **可视化**:
可以绘制损失曲线(训练损失和验证损失)和精度曲线(训练精度和验证精度),以及混淆矩阵热力图。
```python
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.show()
# 绘制混淆矩阵
sns.heatmap(confusion_matrix(y_test, y_pred), annot=True)
```
以上就是基于全连接神经网络对MNIST数据集进行分类分析的基本流程,通过分析预测结果、错误分布和模型性能指标,我们可以更好地了解模型的性能。
阅读全文