In TensorFlow, train VGG16 on MNIST to a high accuracy: give the complete code, print the results, print the run time, and visualize the training.
Posted: 2024-05-11 22:18:01
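Below is complete code for training a VGG16-style network on MNIST with TensorFlow (Keras). It prints the test results and the training time and plots the training curves. Because MNIST images are only 28×28, the network is a scaled-down VGG16: 10 convolutional layers in four blocks instead of 13 in five.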
```python
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
import time
# Load the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
# Preprocess: add a channel axis and scale pixel values to [0, 1] as float32
train_images = train_images.reshape((-1, 28, 28, 1)).astype('float32') / 255.0
test_images = test_images.reshape((-1, 28, 28, 1)).astype('float32') / 255.0
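# Define a VGG16-style model, scaled down for 28x28 inputs:
# four conv blocks (2-2-3-3 layers) instead of VGG16's five (2-2-3-3-3).
# padding='same' keeps the size constant inside each block (28 -> 14 -> 7 -> 3 -> 1
# after pooling); with the default 'valid' padding the feature map shrinks to zero
# in the third block and the model cannot be built.
model = models.Sequential()
# Block 1
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(28, 28, 1)))
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2)))
# Block 2
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2)))
# Block 3
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2)))
# Block 4
model.add(layers.Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(layers.Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(layers.Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(layers.MaxPooling2D((2, 2)))
# Classifier head
model.add(layers.Flatten())
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))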
# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Train the model and time it (the test set doubles as validation data here)
start_time = time.time()
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))
end_time = time.time()
# Evaluate on the test set and print the accuracy
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)
# Print the training time in seconds (covers model.fit only)
print('Time:', end_time - start_time)
# Plot the training curves
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(1, len(acc) + 1)  # derived from the history instead of hard-coding 10
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.xlabel('Epoch')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.tight_layout()
plt.show()
```
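Why the padding mode matters for a 28×28 input can be checked without TensorFlow. The sketch below is plain Python; the helpers `conv_out`, `pool_out` and `trace_vgg_blocks` are hypothetical names made up for illustration. It assumes stride-1 3×3 convolutions and 2×2 max-pooling with stride 2, as in the code above, and traces the 2-2-3-3 block layout. With 'valid' padding (the Keras default for `Conv2D`) the map collapses in the third block; with 'same' padding it goes 28 → 14 → 7 → 3 → 1.

```python
def conv_out(size, kernel=3, padding="valid"):
    """Side length after a stride-1 Conv2D: unchanged for 'same', size - kernel + 1 for 'valid'."""
    return size if padding == "same" else size - kernel + 1


def pool_out(size, pool=2):
    """Side length after MaxPooling2D((2, 2)) with stride 2 and 'valid' padding."""
    return size // pool


def trace_vgg_blocks(size=28, convs_per_block=(2, 2, 3, 3), padding="valid"):
    """Return the side length after each block's pooling layer.

    Raises ValueError as soon as a convolution would produce an empty feature map.
    """
    sizes = []
    for block, n_convs in enumerate(convs_per_block, start=1):
        for _ in range(n_convs):
            size = conv_out(size, padding=padding)
            if size < 1:
                raise ValueError(f"feature map collapses in block {block}")
        size = pool_out(size)
        sizes.append(size)
    return sizes


if __name__ == "__main__":
    print("same :", trace_vgg_blocks(padding="same"))  # same : [14, 7, 3, 1]
    try:
        trace_vgg_blocks(padding="valid")
    except ValueError as err:
        print("valid:", err)  # valid: feature map collapses in block 3
```

With 'same' padding the final map is 1×1×512, so `Flatten` feeds 512 values into the first `Dense(4096)` layer.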
Output as reported in the original answer (per-epoch time depends on the hardware):
```
Epoch 1/10
1875/1875 [==============================] - 104s 55ms/step - loss: 0.1189 - accuracy: 0.9646 - val_loss: 0.0427 - val_accuracy: 0.9870
Epoch 2/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0407 - accuracy: 0.9872 - val_loss: 0.0298 - val_accuracy: 0.9910
Epoch 3/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0292 - accuracy: 0.9916 - val_loss: 0.0315 - val_accuracy: 0.9901
Epoch 4/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0235 - accuracy: 0.9932 - val_loss: 0.0252 - val_accuracy: 0.9927
Epoch 5/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0187 - accuracy: 0.9946 - val_loss: 0.0295 - val_accuracy: 0.9918
Epoch 6/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0163 - accuracy: 0.9953 - val_loss: 0.0324 - val_accuracy: 0.9924
Epoch 7/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0156 - accuracy: 0.9956 - val_loss: 0.0401 - val_accuracy: 0.9922
Epoch 8/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0121 - accuracy: 0.9966 - val_loss: 0.0391 - val_accuracy: 0.9924
Epoch 9/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0107 - accuracy: 0.9970 - val_loss: 0.0359 - val_accuracy: 0.9924
Epoch 10/10
1875/1875 [==============================] - 103s 55ms/step - loss: 0.0121 - accuracy: 0.9969 - val_loss: 0.0445 - val_accuracy: 0.9917
313/313 [==============================] - 5s 16ms/step - loss: 0.0445 - accuracy: 0.9917
Test accuracy: 0.9916999936103821
Time: 1038.7100257873535
```
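To read the curves numerically, a small helper can locate the best epoch in the training history. The sketch below is an illustration, not part of the original answer: `summarize_history` is a hypothetical helper that assumes a dict laid out like Keras' `History.history` (metric name mapped to a per-epoch list), and the numbers are the validation metrics transcribed from the output above.

```python
def summarize_history(hist):
    """Find the best epochs (1-based) in a dict shaped like Keras' History.history."""
    val_acc = hist["val_accuracy"]
    val_loss = hist["val_loss"]
    # Index of the highest validation accuracy / lowest validation loss
    best_acc_epoch = max(range(len(val_acc)), key=val_acc.__getitem__) + 1
    best_loss_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
    return {
        "best_acc_epoch": best_acc_epoch,
        "best_val_accuracy": val_acc[best_acc_epoch - 1],
        "best_loss_epoch": best_loss_epoch,
        # How far the last epoch's val_loss has drifted above its minimum
        "final_val_loss_vs_best": val_loss[-1] - val_loss[best_loss_epoch - 1],
    }


# Validation metrics transcribed from the reported output above
reported = {
    "val_loss": [0.0427, 0.0298, 0.0315, 0.0252, 0.0295,
                 0.0324, 0.0401, 0.0391, 0.0359, 0.0445],
    "val_accuracy": [0.9870, 0.9910, 0.9901, 0.9927, 0.9918,
                     0.9924, 0.9922, 0.9924, 0.9924, 0.9917],
}

summary = summarize_history(reported)
print(summary["best_acc_epoch"], summary["best_val_accuracy"])  # 4 0.9927
print(summary["best_loss_epoch"])                               # 4
print(round(summary["final_val_loss_vs_best"], 4))              # 0.0193
```

In the training script the same helper could be called as `summarize_history(history.history)`. For the reported run, both validation metrics are best at epoch 4, after which val_loss rises by about 0.019, a sign of mild overfitting. One standard way to keep the best weights is to pass `tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)` through the `callbacks` argument of `model.fit`. Because the code above uses the test set as validation data, picking an epoch this way makes the test accuracy optimistic; holding out part of the training data instead (for example with `validation_split=0.1`) avoids that.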
Two plots are also displayed: training and validation accuracy, and training and validation loss, each as a function of the epoch.